Use structured outputs with Pydantic models instead of text parsing
This commit is contained in:
@@ -290,7 +290,11 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
return ClaudeSDKClient(options=options)
|
return ClaudeSDKClient(options=options)
|
||||||
|
|
||||||
def _build_prompt(self, input_value: str) -> str:
|
def _build_prompt(self, input_value: str) -> str:
|
||||||
"""Build prompt from signature docstring, field descriptions, and input value."""
|
"""Build prompt from signature docstring, field descriptions, and input value.
|
||||||
|
|
||||||
|
Note: When using structured outputs, the SDK handles JSON formatting automatically
|
||||||
|
via the output_format parameter, so we don't add JSON instructions to the prompt.
|
||||||
|
"""
|
||||||
prompt_parts = []
|
prompt_parts = []
|
||||||
|
|
||||||
# add signature docstring if present
|
# add signature docstring if present
|
||||||
@@ -325,20 +329,18 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
if output_desc:
|
if output_desc:
|
||||||
prompt_parts.append(f"\nPlease produce the following output: {output_desc}")
|
prompt_parts.append(f"\nPlease produce the following output: {output_desc}")
|
||||||
|
|
||||||
# for Pydantic outputs, add explicit JSON instructions
|
# Don't add JSON instructions - the SDK handles structured outputs via output_format
|
||||||
if self.output_format:
|
# The schema is passed through ClaudeAgentOptions and enforced by the SDK
|
||||||
schema = self.output_format["schema"]
|
|
||||||
prompt_parts.append(
|
|
||||||
f"\nYou MUST respond with ONLY valid JSON matching this schema:\n"
|
|
||||||
f"{json.dumps(schema, indent=2)}\n\n"
|
|
||||||
f"Do not include any explanatory text, markdown formatting, or code blocks. "
|
|
||||||
f"Return ONLY the raw JSON object."
|
|
||||||
)
|
|
||||||
|
|
||||||
return "\n\n".join(prompt_parts)
|
return "\n\n".join(prompt_parts)
|
||||||
|
|
||||||
def _get_output_format(self) -> Optional[dict[str, Any]]:
|
def _get_output_format(self) -> Optional[dict[str, Any]]:
|
||||||
"""Get output format configuration for structured outputs."""
|
"""Get output format configuration for structured outputs.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Direct Pydantic models: MyModel
|
||||||
|
- Generic types: list[MyModel], dict[str, MyModel]
|
||||||
|
"""
|
||||||
output_type = self.output_field.annotation
|
output_type = self.output_field.annotation
|
||||||
|
|
||||||
if is_pydantic_model(output_type):
|
if is_pydantic_model(output_type):
|
||||||
@@ -350,43 +352,68 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _run_async(self, prompt: str) -> tuple[str, list[TraceItem], Usage]:
|
async def _run_async(
|
||||||
"""Run the agent asynchronously and collect results."""
|
self, prompt: str
|
||||||
|
) -> tuple[str | dict | list | None, list[TraceItem], Usage]:
|
||||||
|
"""Run the agent asynchronously and collect results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- response: For structured outputs, returns dict/list from structured_output.
|
||||||
|
For text outputs, returns string from result or text blocks.
|
||||||
|
- trace: Execution trace items
|
||||||
|
- usage: Token usage statistics
|
||||||
|
"""
|
||||||
|
print(f"[ClaudeCode._run_async] Initializing client (connected={self._is_connected})")
|
||||||
|
|
||||||
# create client if needed
|
# create client if needed
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
|
print(f"[ClaudeCode._run_async] Creating new ClaudeSDKClient")
|
||||||
self._client = self._create_client()
|
self._client = self._create_client()
|
||||||
|
|
||||||
# connect if not already connected
|
# connect if not already connected
|
||||||
if not self._is_connected:
|
if not self._is_connected:
|
||||||
|
print(f"[ClaudeCode._run_async] Connecting to Claude SDK...")
|
||||||
await self._client.connect()
|
await self._client.connect()
|
||||||
self._is_connected = True
|
self._is_connected = True
|
||||||
|
print(f"[ClaudeCode._run_async] Connected successfully")
|
||||||
|
|
||||||
# send query (output_format already configured in options)
|
# send query (output_format already configured in options)
|
||||||
|
print(f"[ClaudeCode._run_async] Sending query to agent...")
|
||||||
await self._client.query(prompt)
|
await self._client.query(prompt)
|
||||||
|
print(f"[ClaudeCode._run_async] Query sent, waiting for response...")
|
||||||
|
|
||||||
# collect messages and build trace
|
# collect messages and build trace
|
||||||
trace: list[TraceItem] = []
|
trace: list[TraceItem] = []
|
||||||
usage = Usage()
|
usage = Usage()
|
||||||
response_text = ""
|
response_text = ""
|
||||||
|
structured_output = None
|
||||||
|
message_count = 0
|
||||||
|
|
||||||
async for message in self._client.receive_response():
|
async for message in self._client.receive_response():
|
||||||
|
message_count += 1
|
||||||
|
|
||||||
# handle assistant messages
|
# handle assistant messages
|
||||||
if isinstance(message, AssistantMessage):
|
if isinstance(message, AssistantMessage):
|
||||||
|
print(f"[ClaudeCode._run_async] Received AssistantMessage #{message_count} with {len(message.content)} blocks")
|
||||||
for block in message.content:
|
for block in message.content:
|
||||||
if isinstance(block, TextBlock):
|
if isinstance(block, TextBlock):
|
||||||
|
print(f"[ClaudeCode._run_async] - TextBlock: {len(block.text)} chars")
|
||||||
response_text += block.text
|
response_text += block.text
|
||||||
trace.append(
|
trace.append(
|
||||||
AgentMessageItem(text=block.text, model=message.model)
|
AgentMessageItem(text=block.text, model=message.model)
|
||||||
)
|
)
|
||||||
elif isinstance(block, ThinkingBlock):
|
elif isinstance(block, ThinkingBlock):
|
||||||
|
print(f"[ClaudeCode._run_async] - ThinkingBlock: {len(block.thinking)} chars")
|
||||||
trace.append(
|
trace.append(
|
||||||
ThinkingItem(text=block.thinking, model=message.model)
|
ThinkingItem(text=block.thinking, model=message.model)
|
||||||
)
|
)
|
||||||
elif isinstance(block, ToolUseBlock):
|
elif isinstance(block, ToolUseBlock):
|
||||||
# Handle StructuredOutput tool (contains JSON response)
|
print(f"[ClaudeCode._run_async] - ToolUseBlock: {block.name} (id={block.id})")
|
||||||
|
# handle StructuredOutput tool (contains JSON response)
|
||||||
if block.name == "StructuredOutput":
|
if block.name == "StructuredOutput":
|
||||||
# The JSON is directly in the tool input (already a dict)
|
# the JSON is directly in the tool input (already a dict)
|
||||||
response_text = json.dumps(block.input)
|
response_text = json.dumps(block.input)
|
||||||
|
print(f"[ClaudeCode._run_async] StructuredOutput captured ({len(response_text)} chars)")
|
||||||
|
|
||||||
trace.append(
|
trace.append(
|
||||||
ToolUseItem(
|
ToolUseItem(
|
||||||
@@ -396,11 +423,12 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(block, ToolResultBlock):
|
elif isinstance(block, ToolResultBlock):
|
||||||
|
print(f"[ClaudeCode._run_async] - ToolResultBlock: tool_use_id={block.tool_use_id}, is_error={block.is_error}")
|
||||||
content_str = ""
|
content_str = ""
|
||||||
if isinstance(block.content, str):
|
if isinstance(block.content, str):
|
||||||
content_str = block.content
|
content_str = block.content
|
||||||
elif isinstance(block.content, list):
|
elif isinstance(block.content, list):
|
||||||
# Extract text from content blocks
|
# extract text from content blocks
|
||||||
for item in block.content:
|
for item in block.content:
|
||||||
if (
|
if (
|
||||||
isinstance(item, dict)
|
isinstance(item, dict)
|
||||||
@@ -410,7 +438,7 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
|
|
||||||
trace.append(
|
trace.append(
|
||||||
ToolResultItem(
|
ToolResultItem(
|
||||||
tool_name="", # Tool name not in ToolResultBlock
|
tool_name="", # tool name not in ToolResultBlock
|
||||||
tool_use_id=block.tool_use_id,
|
tool_use_id=block.tool_use_id,
|
||||||
content=content_str,
|
content=content_str,
|
||||||
is_error=block.is_error or False,
|
is_error=block.is_error or False,
|
||||||
@@ -419,9 +447,11 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
|
|
||||||
# handle result messages (final message with usage info)
|
# handle result messages (final message with usage info)
|
||||||
elif isinstance(message, ResultMessage):
|
elif isinstance(message, ResultMessage):
|
||||||
|
print(f"[ClaudeCode._run_async] Received ResultMessage (is_error={getattr(message, 'is_error', False)})")
|
||||||
# store session ID
|
# store session ID
|
||||||
if hasattr(message, "session_id"):
|
if hasattr(message, "session_id"):
|
||||||
self._session_id = message.session_id
|
self._session_id = message.session_id
|
||||||
|
print(f"[ClaudeCode._run_async] - Session ID: {self._session_id}")
|
||||||
|
|
||||||
# extract usage
|
# extract usage
|
||||||
if hasattr(message, "usage") and message.usage:
|
if hasattr(message, "usage") and message.usage:
|
||||||
@@ -433,6 +463,7 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
),
|
),
|
||||||
output_tokens=usage_data.get("output_tokens", 0),
|
output_tokens=usage_data.get("output_tokens", 0),
|
||||||
)
|
)
|
||||||
|
print(f"[ClaudeCode._run_async] - Usage: {usage.input_tokens} in, {usage.output_tokens} out, {usage.cached_input_tokens} cached")
|
||||||
|
|
||||||
# check for errors
|
# check for errors
|
||||||
if hasattr(message, "is_error") and message.is_error:
|
if hasattr(message, "is_error") and message.is_error:
|
||||||
@@ -441,26 +472,41 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
if hasattr(message, "result")
|
if hasattr(message, "result")
|
||||||
else "Unknown error"
|
else "Unknown error"
|
||||||
)
|
)
|
||||||
|
print(f"[ClaudeCode._run_async] - ERROR: {error_msg}")
|
||||||
trace.append(
|
trace.append(
|
||||||
ErrorItem(message=error_msg, error_type="execution_error")
|
ErrorItem(message=error_msg, error_type="execution_error")
|
||||||
)
|
)
|
||||||
raise RuntimeError(f"Agent execution failed: {error_msg}")
|
raise RuntimeError(f"Agent execution failed: {error_msg}")
|
||||||
|
|
||||||
# extract result if present (for structured outputs from result field)
|
# Prefer structured_output over result (when using output_format)
|
||||||
# Note: structured outputs may come from StructuredOutput tool instead
|
if hasattr(message, "structured_output") and message.structured_output is not None:
|
||||||
if hasattr(message, "result") and message.result:
|
structured_output = message.structured_output
|
||||||
|
print(f"[ClaudeCode._run_async] - Structured output captured: {type(structured_output).__name__} ({len(str(structured_output))} chars)")
|
||||||
|
# Fallback to result field for text outputs
|
||||||
|
elif hasattr(message, "result") and message.result:
|
||||||
response_text = message.result
|
response_text = message.result
|
||||||
|
print(f"[ClaudeCode._run_async] - Result extracted from message ({len(response_text)} chars)")
|
||||||
|
|
||||||
# handle system messages
|
# handle system messages
|
||||||
elif isinstance(message, SystemMessage):
|
elif isinstance(message, SystemMessage):
|
||||||
|
print(f"[ClaudeCode._run_async] Received SystemMessage")
|
||||||
# log system messages to trace but don't error
|
# log system messages to trace but don't error
|
||||||
if hasattr(message, "data") and message.data:
|
if hasattr(message, "data") and message.data:
|
||||||
data_str = str(message.data)
|
data_str = str(message.data)
|
||||||
|
print(f"[ClaudeCode._run_async] - Data: {data_str[:100]}..." if len(data_str) > 100 else f"[ClaudeCode._run_async] - Data: {data_str}")
|
||||||
trace.append(
|
trace.append(
|
||||||
AgentMessageItem(text=f"[System: {data_str}]", model="system")
|
AgentMessageItem(text=f"[System: {data_str}]", model="system")
|
||||||
)
|
)
|
||||||
|
|
||||||
return response_text, trace, usage
|
print(f"[ClaudeCode._run_async] Completed: {message_count} messages processed, {len(trace)} trace items")
|
||||||
|
|
||||||
|
# Return structured_output if available (for Pydantic outputs), otherwise text
|
||||||
|
if structured_output is not None:
|
||||||
|
print(f"[ClaudeCode._run_async] Returning structured output")
|
||||||
|
return structured_output, trace, usage
|
||||||
|
else:
|
||||||
|
print(f"[ClaudeCode._run_async] Returning text response")
|
||||||
|
return response_text, trace, usage
|
||||||
|
|
||||||
def forward(self, **kwargs: Any) -> Prediction:
|
def forward(self, **kwargs: Any) -> Prediction:
|
||||||
"""Execute the agent with an input message.
|
"""Execute the agent with an input message.
|
||||||
@@ -480,6 +526,8 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
>>> print(result.trace) # List of execution items
|
>>> print(result.trace) # List of execution items
|
||||||
>>> print(result.usage) # Token usage stats
|
>>> print(result.usage) # Token usage stats
|
||||||
"""
|
"""
|
||||||
|
print(f"\n[ClaudeCode.forward] Called with fields: {list(kwargs.keys())}")
|
||||||
|
|
||||||
# extract input value
|
# extract input value
|
||||||
if self.input_field_name not in kwargs:
|
if self.input_field_name not in kwargs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -488,11 +536,14 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
)
|
)
|
||||||
|
|
||||||
input_value = kwargs[self.input_field_name]
|
input_value = kwargs[self.input_field_name]
|
||||||
|
print(f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value[:100]}..." if len(str(input_value)) > 100 else f"[ClaudeCode.forward] Input field '{self.input_field_name}': {input_value}")
|
||||||
|
|
||||||
# build prompt
|
# build prompt
|
||||||
prompt = self._build_prompt(input_value)
|
prompt = self._build_prompt(input_value)
|
||||||
|
print(f"[ClaudeCode.forward] Built prompt ({len(prompt)} chars): {prompt[:200]}...")
|
||||||
|
|
||||||
# run async execution in event loop
|
# run async execution in event loop
|
||||||
|
print(f"[ClaudeCode.forward] Starting async execution (model={self.model})")
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if loop.is_running():
|
if loop.is_running():
|
||||||
@@ -511,19 +562,37 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
# no event loop, create one
|
# no event loop, create one
|
||||||
response_text, trace, usage = asyncio.run(self._run_async(prompt))
|
response_text, trace, usage = asyncio.run(self._run_async(prompt))
|
||||||
|
|
||||||
|
# Log response details
|
||||||
|
response_type = type(response_text).__name__
|
||||||
|
response_len = len(str(response_text)) if response_text else 0
|
||||||
|
print(f"[ClaudeCode.forward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)")
|
||||||
|
print(f"[ClaudeCode.forward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached")
|
||||||
|
|
||||||
# parse response based on output type
|
# parse response based on output type
|
||||||
output_type = self.output_field.annotation
|
output_type = self.output_field.annotation
|
||||||
if is_pydantic_model(output_type):
|
if is_pydantic_model(output_type):
|
||||||
|
print(f"[ClaudeCode.forward] Parsing structured output (type: {output_type})")
|
||||||
try:
|
try:
|
||||||
|
# response_text can be dict/list (from structured_output) or str (legacy)
|
||||||
parsed_output = parse_json_response(response_text, output_type)
|
parsed_output = parse_json_response(response_text, output_type)
|
||||||
|
print(f"[ClaudeCode.forward] Successfully parsed structured output")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"[ClaudeCode.forward] ERROR: Failed to parse structured output")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to parse Claude response as {output_type.__name__}: {e}\n"
|
f"Failed to parse Claude response as {output_type}: {e}\n"
|
||||||
|
f"Response type: {type(response_text)}\n"
|
||||||
f"Response: {response_text}"
|
f"Response: {response_text}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
print(f"[ClaudeCode.forward] Extracting text response (output type: {output_type})")
|
||||||
# string output - extract text
|
# string output - extract text
|
||||||
parsed_output = extract_text_from_response(response_text)
|
if isinstance(response_text, str):
|
||||||
|
parsed_output = extract_text_from_response(response_text)
|
||||||
|
else:
|
||||||
|
# Shouldn't happen, but handle gracefully
|
||||||
|
parsed_output = str(response_text)
|
||||||
|
|
||||||
|
print(f"[ClaudeCode.forward] Returning Prediction with session_id={self._session_id}\n")
|
||||||
|
|
||||||
# return prediction with typed output, trace, and usage
|
# return prediction with typed output, trace, and usage
|
||||||
return Prediction(
|
return Prediction(
|
||||||
@@ -545,6 +614,8 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
Returns:
|
Returns:
|
||||||
Prediction with typed output, trace, and usage
|
Prediction with typed output, trace, and usage
|
||||||
"""
|
"""
|
||||||
|
print(f"\n[ClaudeCode.aforward] Called with fields: {list(kwargs.keys())}")
|
||||||
|
|
||||||
# extract input value
|
# extract input value
|
||||||
if self.input_field_name not in kwargs:
|
if self.input_field_name not in kwargs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -553,26 +624,47 @@ class ClaudeCode(PrecompiledProgram):
|
|||||||
)
|
)
|
||||||
|
|
||||||
input_value = kwargs[self.input_field_name]
|
input_value = kwargs[self.input_field_name]
|
||||||
|
print(f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value[:100]}..." if len(str(input_value)) > 100 else f"[ClaudeCode.aforward] Input field '{self.input_field_name}': {input_value}")
|
||||||
|
|
||||||
# build prompt
|
# build prompt
|
||||||
prompt = self._build_prompt(input_value)
|
prompt = self._build_prompt(input_value)
|
||||||
|
print(f"[ClaudeCode.aforward] Built prompt ({len(prompt)} chars)")
|
||||||
|
|
||||||
# run async execution
|
# run async execution
|
||||||
|
print(f"[ClaudeCode.aforward] Starting async execution (model={self.model})")
|
||||||
response_text, trace, usage = await self._run_async(prompt)
|
response_text, trace, usage = await self._run_async(prompt)
|
||||||
|
|
||||||
|
# Log response details
|
||||||
|
response_type = type(response_text).__name__
|
||||||
|
response_len = len(str(response_text)) if response_text else 0
|
||||||
|
print(f"[ClaudeCode.aforward] Received response (type={response_type}, {response_len} chars, {len(trace)} trace items)")
|
||||||
|
print(f"[ClaudeCode.aforward] Token usage: {usage.input_tokens} input, {usage.output_tokens} output, {usage.cached_input_tokens} cached")
|
||||||
|
|
||||||
# parse response based on output type
|
# parse response based on output type
|
||||||
output_type = self.output_field.annotation
|
output_type = self.output_field.annotation
|
||||||
if is_pydantic_model(output_type):
|
if is_pydantic_model(output_type):
|
||||||
|
print(f"[ClaudeCode.aforward] Parsing structured output (type: {output_type})")
|
||||||
try:
|
try:
|
||||||
|
# response_text can be dict/list (from structured_output) or str (legacy)
|
||||||
parsed_output = parse_json_response(response_text, output_type)
|
parsed_output = parse_json_response(response_text, output_type)
|
||||||
|
print(f"[ClaudeCode.aforward] Successfully parsed structured output")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"[ClaudeCode.aforward] ERROR: Failed to parse structured output")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to parse Claude response as {output_type.__name__}: {e}\n"
|
f"Failed to parse Claude response as {output_type}: {e}\n"
|
||||||
|
f"Response type: {type(response_text)}\n"
|
||||||
f"Response: {response_text}"
|
f"Response: {response_text}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
print(f"[ClaudeCode.aforward] Extracting text response (output type: {output_type})")
|
||||||
# string output - extract text
|
# string output - extract text
|
||||||
parsed_output = extract_text_from_response(response_text)
|
if isinstance(response_text, str):
|
||||||
|
parsed_output = extract_text_from_response(response_text)
|
||||||
|
else:
|
||||||
|
# Shouldn't happen, but handle gracefully
|
||||||
|
parsed_output = str(response_text)
|
||||||
|
|
||||||
|
print(f"[ClaudeCode.aforward] Returning Prediction with session_id={self._session_id}\n")
|
||||||
|
|
||||||
# return prediction with typed output, trace, and usage
|
# return prediction with typed output, trace, and usage
|
||||||
return Prediction(
|
return Prediction(
|
||||||
|
|||||||
@@ -27,24 +27,96 @@ class Usage:
|
|||||||
|
|
||||||
|
|
||||||
def is_pydantic_model(type_hint: Any) -> bool:
|
def is_pydantic_model(type_hint: Any) -> bool:
|
||||||
"""Check if a type hint is a Pydantic model."""
|
"""Check if a type hint is a Pydantic model or contains one (e.g., list[Model]).
|
||||||
|
|
||||||
|
Returns True for:
|
||||||
|
- Pydantic models: MyModel
|
||||||
|
- Generic types with a Pydantic model argument: list[MyModel], dict[str, MyModel]
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return isinstance(type_hint, type) and issubclass(type_hint, BaseModel)
|
# Direct Pydantic model
|
||||||
|
if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Generic type (list, dict, etc.)
|
||||||
|
origin = get_origin(type_hint)
|
||||||
|
if origin is not None:
|
||||||
|
args = get_args(type_hint)
|
||||||
|
# Check if any type argument is a Pydantic model
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
except TypeError:
|
except TypeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_json_schema(pydantic_model: type[BaseModel]) -> dict[str, Any]:
|
def get_json_schema(type_hint: Any) -> dict[str, Any]:
|
||||||
"""Generate JSON schema from Pydantic model.
|
"""Generate JSON schema from type hint.
|
||||||
|
|
||||||
Sets additionalProperties to false to match Codex behavior.
|
Handles:
|
||||||
|
- Pydantic models: MyModel
|
||||||
|
- Generic types: list[MyModel], dict[str, MyModel]
|
||||||
|
|
||||||
|
Note: Claude API requires root type to be "object" for structured outputs (tools).
|
||||||
|
For list/dict types, we wrap them in an object with a single property.
|
||||||
|
|
||||||
|
Sets additionalProperties to false for all objects.
|
||||||
"""
|
"""
|
||||||
schema = pydantic_model.model_json_schema()
|
origin = get_origin(type_hint)
|
||||||
|
args = get_args(type_hint)
|
||||||
|
|
||||||
# Recursively set additionalProperties: false for all objects
|
# Handle generic types (list, dict, etc.)
|
||||||
|
if origin is list:
|
||||||
|
# list[Model] - wrap in object since API requires root type = "object"
|
||||||
|
# {"type": "object", "properties": {"items": {"type": "array", "items": {...}}}}
|
||||||
|
if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
|
||||||
|
model = args[0]
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"items": {
|
||||||
|
"type": "array",
|
||||||
|
"items": model.model_json_schema()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["items"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported list type: {type_hint}")
|
||||||
|
|
||||||
|
elif origin is dict:
|
||||||
|
# dict[str, Model] - wrap in object since API requires root type = "object"
|
||||||
|
# {"type": "object", "properties": {"values": {"type": "object", "additionalProperties": {...}}}}
|
||||||
|
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||||
|
model = args[1]
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"values": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": model.model_json_schema()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["values"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported dict type: {type_hint}")
|
||||||
|
|
||||||
|
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||||
|
# Direct Pydantic model - already an object
|
||||||
|
schema = type_hint.model_json_schema()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported type for structured output: {type_hint}")
|
||||||
|
|
||||||
|
# Recursively set additionalProperties: false for all nested objects
|
||||||
def set_additional_properties(obj: dict[str, Any]) -> None:
|
def set_additional_properties(obj: dict[str, Any]) -> None:
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
if obj.get("type") == "object":
|
if obj.get("type") == "object" and "additionalProperties" not in obj:
|
||||||
obj["additionalProperties"] = False
|
obj["additionalProperties"] = False
|
||||||
for value in obj.values():
|
for value in obj.values():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
@@ -58,14 +130,74 @@ def get_json_schema(pydantic_model: type[BaseModel]) -> dict[str, Any]:
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def parse_json_response(response: str, pydantic_model: type[BaseModel]) -> BaseModel:
|
def parse_json_response(
|
||||||
"""Parse JSON response into Pydantic model.
|
response: str | dict | list, type_hint: Any
|
||||||
|
) -> BaseModel | list[BaseModel] | dict[str, BaseModel]:
|
||||||
|
"""Parse JSON response into typed output.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Pydantic models: MyModel
|
||||||
|
- Generic types: list[MyModel], dict[str, MyModel]
|
||||||
|
|
||||||
|
Note: When schema has list/dict at root, the SDK wraps them in {"items": [...]}
|
||||||
|
or {"values": {...}} because API requires root type = "object".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: JSON string or already-parsed dict/list from structured_output
|
||||||
|
type_hint: The output type annotation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated and typed output
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
json.JSONDecodeError: If response is not valid JSON
|
json.JSONDecodeError: If response string is not valid JSON
|
||||||
pydantic.ValidationError: If JSON doesn't match model schema
|
pydantic.ValidationError: If JSON doesn't match schema
|
||||||
"""
|
"""
|
||||||
return pydantic_model.model_validate_json(response)
|
import json
|
||||||
|
|
||||||
|
origin = get_origin(type_hint)
|
||||||
|
args = get_args(type_hint)
|
||||||
|
|
||||||
|
# Parse string to dict/list if needed
|
||||||
|
if isinstance(response, str):
|
||||||
|
parsed = json.loads(response)
|
||||||
|
else:
|
||||||
|
parsed = response
|
||||||
|
|
||||||
|
# Handle list[Model]
|
||||||
|
if origin is list:
|
||||||
|
if args and isinstance(args[0], type) and issubclass(args[0], BaseModel):
|
||||||
|
model = args[0]
|
||||||
|
|
||||||
|
# Unwrap from {"items": [...]} if present (from structured_output)
|
||||||
|
if isinstance(parsed, dict) and "items" in parsed:
|
||||||
|
parsed = parsed["items"]
|
||||||
|
|
||||||
|
if not isinstance(parsed, list):
|
||||||
|
raise ValueError(f"Expected list, got {type(parsed)}")
|
||||||
|
return [model.model_validate(item) for item in parsed]
|
||||||
|
|
||||||
|
# Handle dict[str, Model]
|
||||||
|
elif origin is dict:
|
||||||
|
if len(args) >= 2 and isinstance(args[1], type) and issubclass(args[1], BaseModel):
|
||||||
|
model = args[1]
|
||||||
|
|
||||||
|
# Unwrap from {"values": {...}} if present (from structured_output)
|
||||||
|
if isinstance(parsed, dict) and "values" in parsed:
|
||||||
|
parsed = parsed["values"]
|
||||||
|
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValueError(f"Expected dict, got {type(parsed)}")
|
||||||
|
return {key: model.model_validate(value) for key, value in parsed.items()}
|
||||||
|
|
||||||
|
# Handle direct Pydantic model
|
||||||
|
elif isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||||
|
if isinstance(response, str):
|
||||||
|
return type_hint.model_validate_json(response)
|
||||||
|
else:
|
||||||
|
return type_hint.model_validate(parsed)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported type for parsing: {type_hint}")
|
||||||
|
|
||||||
|
|
||||||
def extract_text_from_response(response: str) -> str:
|
def extract_text_from_response(response: str) -> str:
|
||||||
|
|||||||
38
main.py
38
main.py
@@ -1,16 +1,21 @@
|
|||||||
from claude_dspy import ClaudeCode, ClaudeCodeConfig
|
from claude_dspy import ClaudeCode, ClaudeCodeConfig
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from modaic import AutoProgram
|
from modaic import AutoProgram
|
||||||
import dspy
|
import dspy
|
||||||
|
|
||||||
|
|
||||||
class Output(BaseModel):
|
class ErrorReport(BaseModel):
|
||||||
files: list[str]
|
error: str = Field(..., description="Error message")
|
||||||
|
timestamp: str = Field(..., description="Timestamp of error")
|
||||||
|
line_located: int | None = Field(..., description="Line number where error occurred")
|
||||||
|
file_located: str | None = Field(..., description="File where error occurred")
|
||||||
|
description: str = Field(..., description="Description of errors")
|
||||||
|
reccomedated_fixes: list[str] = Field(..., description="List of recommended fixes")
|
||||||
|
|
||||||
|
|
||||||
class ClaudeCodeSignature(dspy.Signature):
|
class ClaudeCodeSignature(dspy.Signature):
|
||||||
message: str = dspy.InputField(desc="Request to process")
|
log_file: str = dspy.InputField(desc="Log file to process")
|
||||||
output: Output = dspy.OutputField(desc="List of files modified or created")
|
output: list[ErrorReport] = dspy.OutputField(desc="list of error reports created")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -23,25 +28,32 @@ def main():
|
|||||||
signature=ClaudeCodeSignature,
|
signature=ClaudeCodeSignature,
|
||||||
working_directory=".",
|
working_directory=".",
|
||||||
permission_mode="acceptEdits",
|
permission_mode="acceptEdits",
|
||||||
allowed_tools=["Read", "Bash", "Write"],
|
allowed_tools=["Read", "Write", "Bash"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc.push_to_hub(
|
cc.push_to_hub(
|
||||||
"farouk1/claude-code",
|
"farouk1/claude-code",
|
||||||
with_code=True,
|
with_code=True,
|
||||||
commit_message="Add more functionality to signature description parsing",
|
commit_message="Use structured outputs with Pydantic models instead of text parsing",
|
||||||
)
|
)
|
||||||
|
|
||||||
# agent = AutoProgram.from_precompiled("farouk1/claude-code", signature=ClaudeCodeSignature, working_directory=".", permission_mode="acceptEdits", allowed_tools=["Read", "Write", "Bash"])
|
# agent = AutoProgram.from_precompiled("farouk1/claude-code", signature=ClaudeCodeSignature, working_directory=".", permission_mode="acceptEdits", allowed_tools=["Read", "Write", "Bash"])
|
||||||
|
"""
|
||||||
# Test the agent
|
print("Running Claude Code...")
|
||||||
result = cc(message="create a python program that prints 'Hello, World!' and save it to a file in this directory")
|
result = cc(log_file="log.txt")
|
||||||
print(result.output.files)
|
print("Claude Code finished running!")
|
||||||
print(result.output)
|
print("-" * 50)
|
||||||
print(result.usage)
|
print("Output: ", result.output)
|
||||||
|
print("Output type: ", type(result.output))
|
||||||
|
print("Usage: ", result.usage)
|
||||||
|
print("Output type: ", type(result.output))
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("Trace:")
|
print("Trace:")
|
||||||
for i, item in enumerate(result.trace):
|
for i, item in enumerate(result.trace):
|
||||||
print(f" {i}: {item}")
|
print(f" {i}: {item}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user