(no commit message)
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import dspy
|
||||
from .utils import OPENROUTER_API_BASE, OPENROUTER_API_KEY
|
||||
from .modules import signature_generator
|
||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||
|
||||
class PromptToSignatureConfig(PrecompiledConfig):
|
||||
lm: str = "gpt-5-2025-08-07"
|
||||
refine_lm: str = "gpt-5-2025-08-07"
|
||||
lm: str = "openrouter/anthropic/claude-haiku-4.5"
|
||||
refine_lm: str = "openai/gpt-4o"
|
||||
max_tokens: int = 16000
|
||||
temperature: float = 1.0
|
||||
max_attempts_to_refine: int = 5
|
||||
@@ -27,6 +28,8 @@ class PromptToSignatureAgent(PrecompiledAgent):
|
||||
model=config.lm,
|
||||
max_tokens=config.max_tokens,
|
||||
temperature=config.temperature,
|
||||
api_base=OPENROUTER_API_BASE,
|
||||
api_key=OPENROUTER_API_KEY
|
||||
)
|
||||
refine_lm = dspy.LM(
|
||||
model=config.refine_lm,
|
||||
|
||||
@@ -2,6 +2,7 @@ import dspy
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Literal, Union, get_origin
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from time import time
|
||||
|
||||
class FieldType(str, Enum):
|
||||
STRING = "str"
|
||||
@@ -167,16 +168,18 @@ class SignatureGeneration(dspy.Signature):
|
||||
signature_fields: list[GeneratedField] = dspy.OutputField(
|
||||
desc="""List of input and output fields for the signature.
|
||||
|
||||
CRITICAL RULES:
|
||||
1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema
|
||||
2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs
|
||||
3. For Pydantic models: define ALL nested fields properly in the schema with correct types
|
||||
4. Simple outputs (single values) can use basic types like str, int, bool
|
||||
5. Use Literal types for enumerated values (e.g., severity levels, status codes)
|
||||
CRITICAL RULES:
|
||||
1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema
|
||||
2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs
|
||||
3. For Pydantic models: define ALL nested fields properly in the schema with correct types
|
||||
4. Simple outputs (single values) can use basic types like str, int, bool
|
||||
5. Use Literal types for enumerated values (e.g., severity levels, status codes)
|
||||
|
||||
Examples:
|
||||
- BAD: structured_output: str = "A JSON containing patient data..."
|
||||
- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)"""
|
||||
Examples:
|
||||
- BAD: structured_output: str = "A JSON containing patient data..."
|
||||
- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)
|
||||
|
||||
"""
|
||||
)
|
||||
signature_name: str = dspy.OutputField(
|
||||
desc="Suggested class name for the signature (PascalCase)"
|
||||
@@ -186,11 +189,14 @@ Examples:
|
||||
class SignatureGenerator(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.generator = dspy.ChainOfThought(SignatureGeneration)
|
||||
self.generator = dspy.Predict(SignatureGeneration)
|
||||
|
||||
def forward(self, prompt: str):
|
||||
"""Generate DSPy signature and return raw prediction attributes"""
|
||||
start_time = time()
|
||||
result = self.generator(prompt=prompt)
|
||||
end_time = time()
|
||||
print(f"Signature generation took {end_time - start_time:.2f} seconds in inference.")
|
||||
|
||||
return dspy.Prediction(
|
||||
signature_name=result.signature_name,
|
||||
|
||||
30
agent/utils.py
Normal file
30
agent/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def to_snake_case(name: str) -> str:
|
||||
"""Convert PascalCase or camelCase string to snake_case."""
|
||||
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
||||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
||||
|
||||
|
||||
def save_signature_to_file(
|
||||
result: Dict[str, Any], output_path: str | None = None
|
||||
) -> str:
|
||||
"""Save generated signature code to a Python file"""
|
||||
if not output_path:
|
||||
signature_name = result.get("signature_name", "generated_signature")
|
||||
output_path = f"{to_snake_case(signature_name)}.py"
|
||||
|
||||
if result.get("code"):
|
||||
with open(output_path, "w") as f:
|
||||
f.write(result["code"])
|
||||
return output_path
|
||||
else:
|
||||
raise ValueError("No code to save")
|
||||
|
||||
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
|
||||
Reference in New Issue
Block a user