(no commit message)

This commit is contained in:
2025-11-04 13:29:58 -05:00
parent 87433012c0
commit 22d47004cd
6 changed files with 143 additions and 22 deletions

View File

@@ -1,10 +1,11 @@
import dspy
from .utils import OPENROUTER_API_BASE, OPENROUTER_API_KEY
from .modules import signature_generator
from modaic import PrecompiledAgent, PrecompiledConfig
class PromptToSignatureConfig(PrecompiledConfig):
lm: str = "gpt-5-2025-08-07"
refine_lm: str = "gpt-5-2025-08-07"
lm: str = "openrouter/anthropic/claude-haiku-4.5"
refine_lm: str = "openai/gpt-4o"
max_tokens: int = 16000
temperature: float = 1.0
max_attempts_to_refine: int = 5
@@ -27,6 +28,8 @@ class PromptToSignatureAgent(PrecompiledAgent):
model=config.lm,
max_tokens=config.max_tokens,
temperature=config.temperature,
api_base=OPENROUTER_API_BASE,
api_key=OPENROUTER_API_KEY
)
refine_lm = dspy.LM(
model=config.refine_lm,

View File

@@ -2,6 +2,7 @@ import dspy
from enum import Enum
from typing import Any, Dict, List, Optional, Literal, Union, get_origin
from pydantic import BaseModel, Field, ValidationError
from time import time
class FieldType(str, Enum):
STRING = "str"
@@ -167,16 +168,18 @@ class SignatureGeneration(dspy.Signature):
signature_fields: list[GeneratedField] = dspy.OutputField(
desc="""List of input and output fields for the signature.
CRITICAL RULES:
1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema
2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs
3. For Pydantic models: define ALL nested fields properly in the schema with correct types
4. Simple outputs (single values) can use basic types like str, int, bool
5. Use Literal types for enumerated values (e.g., severity levels, status codes)
CRITICAL RULES:
1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema
2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs
3. For Pydantic models: define ALL nested fields properly in the schema with correct types
4. Simple outputs (single values) can use basic types like str, int, bool
5. Use Literal types for enumerated values (e.g., severity levels, status codes)
Examples:
- BAD: structured_output: str = "A JSON containing patient data..."
- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)"""
Examples:
- BAD: structured_output: str = "A JSON containing patient data..."
- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)
"""
)
signature_name: str = dspy.OutputField(
desc="Suggested class name for the signature (PascalCase)"
@@ -186,11 +189,14 @@ Examples:
class SignatureGenerator(dspy.Module):
def __init__(self):
super().__init__()
self.generator = dspy.ChainOfThought(SignatureGeneration)
self.generator = dspy.Predict(SignatureGeneration)
def forward(self, prompt: str):
"""Generate DSPy signature and return raw prediction attributes"""
start_time = time()
result = self.generator(prompt=prompt)
end_time = time()
print(f"Signature generation took {end_time - start_time:.2f} seconds in inference.")
return dspy.Prediction(
signature_name=result.signature_name,

30
agent/utils.py Normal file
View File

@@ -0,0 +1,30 @@
from typing import Dict, Any
import re
import os
from dotenv import load_dotenv
load_dotenv()
def to_snake_case(name: str) -> str:
    """Return ``name`` rewritten from PascalCase or camelCase to snake_case.

    Underscores are inserted in two passes and the result is lower-cased:

    1. before a capitalised word (an uppercase letter followed by one or
       more lowercase letters) that has any character in front of it;
    2. between a lowercase letter or digit and an uppercase letter that
       directly follows it.
    """
    word_boundary = r"(.)([A-Z][a-z]+)"
    case_boundary = r"([a-z0-9])([A-Z])"
    joined = r"\1_\2"

    with_word_breaks = re.sub(word_boundary, joined, name)
    with_all_breaks = re.sub(case_boundary, joined, with_word_breaks)
    return with_all_breaks.lower()
def save_signature_to_file(
result: Dict[str, Any], output_path: str | None = None
) -> str:
"""Save generated signature code to a Python file"""
if not output_path:
signature_name = result.get("signature_name", "generated_signature")
output_path = f"{to_snake_case(signature_name)}.py"
if result.get("code"):
with open(output_path, "w") as f:
f.write(result["code"])
return output_path
else:
raise ValueError("No code to save")
# Base URL of the OpenRouter API.
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
# Looked up in the process environment at import time; None when the
# OPENROUTER_API_KEY variable is not set.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")