Add more functionality to signature description parsing
This commit is contained in:
@@ -44,19 +44,61 @@ class ClaudeCodeConfig(PrecompiledConfig):
|
||||
|
||||
model: str = "claude-opus-4-5-20251101"
|
||||
|
||||
|
||||
class ClaudeCodeKwargs(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
"""Arguments for ClaudeCode initialization.
|
||||
|
||||
signature: Any # str | dspy.Signature (validated manually)
|
||||
Matches ClaudeAgentOptions from the SDK with additional DSPy-specific fields.
|
||||
See: https://platform.claude.com/docs/en/agent-sdk/python#claudeagentoptions
|
||||
"""
|
||||
|
||||
# DSPy-specific (required)
|
||||
signature: Any # str | dspy.Signature - validated manually in __init__
|
||||
|
||||
# Authentication
|
||||
api_key: str | None = None
|
||||
|
||||
# Basic configuration
|
||||
working_directory: str = "."
|
||||
permission_mode: str | None = None
|
||||
allowed_tools: list[str] | None = None
|
||||
allowed_tools: list[str] | None = None # Any Claude Code tools
|
||||
disallowed_tools: list[str] | None = None
|
||||
sandbox: dict[str, Any] | None = None
|
||||
system_prompt: str | dict[str, Any] | None = None
|
||||
|
||||
# MCP servers
|
||||
mcp_servers: dict[str, Any] | str | Path | None = None
|
||||
|
||||
# Session management
|
||||
continue_conversation: bool = False
|
||||
resume: str | None = None
|
||||
max_turns: int | None = None
|
||||
fork_session: bool = False
|
||||
|
||||
# Advanced options
|
||||
permission_prompt_tool_name: str | None = None
|
||||
settings: str | None = None
|
||||
add_dirs: list[str | Path] | None = None
|
||||
env: dict[str, str] | None = None
|
||||
extra_args: dict[str, str | None] | None = None
|
||||
max_buffer_size: int | None = None
|
||||
|
||||
# Callbacks and hooks
|
||||
stderr: Any | None = None # Callable[[str], None] - can't type check callables in Pydantic easily
|
||||
can_use_tool: Any | None = None # CanUseTool callback
|
||||
hooks: dict[str, list[dict[str, Any]]] | None = None
|
||||
|
||||
# User and settings
|
||||
user: str | None = None
|
||||
include_partial_messages: bool = False
|
||||
setting_sources: list[str] | None = None # List of "user" | "project" | "local"
|
||||
|
||||
# Subagents and plugins
|
||||
agents: dict[str, dict[str, Any]] | None = None
|
||||
plugins: list[dict[str, Any]] | None = None
|
||||
|
||||
# CLI configuration
|
||||
cli_path: str | Path | None = None
|
||||
|
||||
|
||||
class ClaudeCode(PrecompiledProgram):
|
||||
"""DSPy module that wraps Claude Code SDK.
|
||||
@@ -88,22 +130,14 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
args = ClaudeCodeKwargs(**kwargs)
|
||||
|
||||
# Parse and validate signature
|
||||
signature = args.signature
|
||||
api_key = args.api_key
|
||||
working_directory = args.working_directory
|
||||
permission_mode = args.permission_mode
|
||||
allowed_tools = args.allowed_tools
|
||||
disallowed_tools = args.disallowed_tools
|
||||
sandbox = args.sandbox
|
||||
system_prompt = args.system_prompt
|
||||
|
||||
# parse and validate signature
|
||||
if isinstance(signature, str):
|
||||
self.signature = dspy.Signature(signature)
|
||||
else:
|
||||
self.signature = signature
|
||||
|
||||
# validate signature has exactly 1 input and 1 output
|
||||
# Validate signature has exactly 1 input and 1 output
|
||||
input_fields = list(self.signature.input_fields.keys())
|
||||
output_fields = list(self.signature.output_fields.keys())
|
||||
|
||||
@@ -124,17 +158,51 @@ class ClaudeCode(PrecompiledProgram):
|
||||
self.input_field = self.signature.input_fields[self.input_field_name]
|
||||
self.output_field = self.signature.output_fields[self.output_field_name]
|
||||
|
||||
# store config values
|
||||
self.working_directory = Path(working_directory).resolve()
|
||||
# Store all configuration values
|
||||
self.api_key = args.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.working_directory = Path(args.working_directory).resolve()
|
||||
self.model = config.model
|
||||
self.permission_mode = permission_mode
|
||||
self.allowed_tools = allowed_tools
|
||||
self.disallowed_tools = disallowed_tools
|
||||
self.sandbox = sandbox
|
||||
self.system_prompt = system_prompt
|
||||
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
# No extra options since all kwargs are parsed by ClaudeCodeKwargs
|
||||
self.extra_options = {}
|
||||
|
||||
# Basic options
|
||||
self.permission_mode = args.permission_mode
|
||||
self.allowed_tools = args.allowed_tools
|
||||
self.disallowed_tools = args.disallowed_tools
|
||||
self.sandbox = args.sandbox
|
||||
self.system_prompt = args.system_prompt
|
||||
|
||||
# MCP servers
|
||||
self.mcp_servers = args.mcp_servers
|
||||
|
||||
# Session management
|
||||
self.continue_conversation = args.continue_conversation
|
||||
self.resume = args.resume
|
||||
self.max_turns = args.max_turns
|
||||
self.fork_session = args.fork_session
|
||||
|
||||
# Advanced options
|
||||
self.permission_prompt_tool_name = args.permission_prompt_tool_name
|
||||
self.settings = args.settings
|
||||
self.add_dirs = args.add_dirs
|
||||
self.env = args.env
|
||||
self.extra_args = args.extra_args
|
||||
self.max_buffer_size = args.max_buffer_size
|
||||
|
||||
# Callbacks and hooks
|
||||
self.stderr = args.stderr
|
||||
self.can_use_tool = args.can_use_tool
|
||||
self.hooks = args.hooks
|
||||
|
||||
# User and settings
|
||||
self.user = args.user
|
||||
self.include_partial_messages = args.include_partial_messages
|
||||
self.setting_sources = args.setting_sources
|
||||
|
||||
# Subagents and plugins
|
||||
self.agents = args.agents
|
||||
self.plugins = args.plugins
|
||||
|
||||
# CLI configuration
|
||||
self.cli_path = args.cli_path
|
||||
|
||||
# determine output format upfront
|
||||
self.output_format = self._get_output_format()
|
||||
@@ -154,19 +222,68 @@ class ClaudeCode(PrecompiledProgram):
|
||||
|
||||
def _create_client(self) -> ClaudeSDKClient:
|
||||
"""Create ClaudeSDKClient with configured options."""
|
||||
options = ClaudeAgentOptions(
|
||||
cwd=str(self.working_directory),
|
||||
model=self.model,
|
||||
permission_mode=self.permission_mode,
|
||||
allowed_tools=self.allowed_tools or [],
|
||||
disallowed_tools=self.disallowed_tools or [],
|
||||
sandbox=self.sandbox,
|
||||
system_prompt=self.system_prompt,
|
||||
output_format=self.output_format, # include output format
|
||||
**self.extra_options,
|
||||
)
|
||||
# Build options dict, only including non-None values
|
||||
options_dict = {
|
||||
"cwd": str(self.working_directory),
|
||||
"model": self.model,
|
||||
"output_format": self.output_format,
|
||||
}
|
||||
|
||||
# set API key if provided
|
||||
# Add optional fields only if they're not None
|
||||
if self.permission_mode is not None:
|
||||
options_dict["permission_mode"] = self.permission_mode
|
||||
if self.allowed_tools is not None:
|
||||
options_dict["allowed_tools"] = self.allowed_tools
|
||||
if self.disallowed_tools is not None:
|
||||
options_dict["disallowed_tools"] = self.disallowed_tools
|
||||
if self.sandbox is not None:
|
||||
options_dict["sandbox"] = self.sandbox
|
||||
if self.system_prompt is not None:
|
||||
options_dict["system_prompt"] = self.system_prompt
|
||||
if self.mcp_servers is not None:
|
||||
options_dict["mcp_servers"] = self.mcp_servers
|
||||
if self.continue_conversation:
|
||||
options_dict["continue_conversation"] = self.continue_conversation
|
||||
if self.resume is not None:
|
||||
options_dict["resume"] = self.resume
|
||||
if self.max_turns is not None:
|
||||
options_dict["max_turns"] = self.max_turns
|
||||
if self.fork_session:
|
||||
options_dict["fork_session"] = self.fork_session
|
||||
if self.permission_prompt_tool_name is not None:
|
||||
options_dict["permission_prompt_tool_name"] = self.permission_prompt_tool_name
|
||||
if self.settings is not None:
|
||||
options_dict["settings"] = self.settings
|
||||
if self.add_dirs is not None:
|
||||
options_dict["add_dirs"] = self.add_dirs
|
||||
if self.env is not None:
|
||||
options_dict["env"] = self.env
|
||||
if self.extra_args is not None:
|
||||
options_dict["extra_args"] = self.extra_args
|
||||
if self.max_buffer_size is not None:
|
||||
options_dict["max_buffer_size"] = self.max_buffer_size
|
||||
if self.stderr is not None:
|
||||
options_dict["stderr"] = self.stderr
|
||||
if self.can_use_tool is not None:
|
||||
options_dict["can_use_tool"] = self.can_use_tool
|
||||
if self.hooks is not None:
|
||||
options_dict["hooks"] = self.hooks
|
||||
if self.user is not None:
|
||||
options_dict["user"] = self.user
|
||||
if self.include_partial_messages:
|
||||
options_dict["include_partial_messages"] = self.include_partial_messages
|
||||
if self.setting_sources is not None:
|
||||
options_dict["setting_sources"] = self.setting_sources
|
||||
if self.agents is not None:
|
||||
options_dict["agents"] = self.agents
|
||||
if self.plugins is not None:
|
||||
options_dict["plugins"] = self.plugins
|
||||
if self.cli_path is not None:
|
||||
options_dict["cli_path"] = self.cli_path
|
||||
|
||||
options = ClaudeAgentOptions(**options_dict)
|
||||
|
||||
# Set API key if provided
|
||||
if self.api_key:
|
||||
os.environ["ANTHROPIC_API_KEY"] = self.api_key
|
||||
|
||||
|
||||
Reference in New Issue
Block a user