(no commit message)

This commit is contained in:
2025-11-04 13:29:58 -05:00
parent 87433012c0
commit 22d47004cd
6 changed files with 143 additions and 22 deletions

View File

@@ -1,5 +1,5 @@
{ {
"signature_generator.generator.predict": { "signature_generator.generator": {
"traces": [], "traces": [],
"train": [], "train": [],
"demos": [], "demos": [],
@@ -10,17 +10,13 @@
"prefix": "Prompt:", "prefix": "Prompt:",
"description": "Natural language description of the desired functionality." "description": "Natural language description of the desired functionality."
}, },
{
"prefix": "Reasoning: Let's think step by step in order to",
"description": "${reasoning}"
},
{ {
"prefix": "Task Description:", "prefix": "Task Description:",
"description": "Clear description of what the signature aims to accomplish." "description": "Clear description of what the signature aims to accomplish."
}, },
{ {
"prefix": "Signature Fields:", "prefix": "Signature Fields:",
"description": "List of input and output fields for the signature.\n\nCRITICAL RULES:\n1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema\n2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs\n3. For Pydantic models: define ALL nested fields properly in the schema with correct types\n4. Simple outputs (single values) can use basic types like str, int, bool\n5. Use Literal types for enumerated values (e.g., severity levels, status codes)\n\nExamples:\n- BAD: structured_output: str = \"A JSON containing patient data...\"\n- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)" "description": "List of input and output fields for the signature.\n\n CRITICAL RULES:\n 1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema\n 2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs\n 3. For Pydantic models: define ALL nested fields properly in the schema with correct types\n 4. Simple outputs (single values) can use basic types like str, int, bool\n 5. Use Literal types for enumerated values (e.g., severity levels, status codes)\n\n Examples:\n - BAD: structured_output: str = \"A JSON containing patient data...\"\n - GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)\n \n "
}, },
{ {
"prefix": "Signature Name:", "prefix": "Signature Name:",
@@ -29,7 +25,7 @@
] ]
}, },
"lm": { "lm": {
"model": "gpt-5-2025-08-07", "model": "openai/gpt-4o",
"model_type": "chat", "model_type": "chat",
"cache": true, "cache": true,
"num_retries": 3, "num_retries": 3,
@@ -37,7 +33,7 @@
"launch_kwargs": {}, "launch_kwargs": {},
"train_kwargs": {}, "train_kwargs": {},
"temperature": 1.0, "temperature": 1.0,
"max_completion_tokens": 16000 "max_tokens": 16000
} }
}, },
"metadata": { "metadata": {

View File

@@ -1,10 +1,11 @@
import dspy import dspy
from .utils import OPENROUTER_API_BASE, OPENROUTER_API_KEY
from .modules import signature_generator from .modules import signature_generator
from modaic import PrecompiledAgent, PrecompiledConfig from modaic import PrecompiledAgent, PrecompiledConfig
class PromptToSignatureConfig(PrecompiledConfig): class PromptToSignatureConfig(PrecompiledConfig):
lm: str = "gpt-5-2025-08-07" lm: str = "openrouter/anthropic/claude-haiku-4.5"
refine_lm: str = "gpt-5-2025-08-07" refine_lm: str = "openai/gpt-4o"
max_tokens: int = 16000 max_tokens: int = 16000
temperature: float = 1.0 temperature: float = 1.0
max_attempts_to_refine: int = 5 max_attempts_to_refine: int = 5
@@ -27,6 +28,8 @@ class PromptToSignatureAgent(PrecompiledAgent):
model=config.lm, model=config.lm,
max_tokens=config.max_tokens, max_tokens=config.max_tokens,
temperature=config.temperature, temperature=config.temperature,
api_base=OPENROUTER_API_BASE,
api_key=OPENROUTER_API_KEY
) )
refine_lm = dspy.LM( refine_lm = dspy.LM(
model=config.refine_lm, model=config.refine_lm,

View File

@@ -2,6 +2,7 @@ import dspy
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Literal, Union, get_origin from typing import Any, Dict, List, Optional, Literal, Union, get_origin
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
from time import time
class FieldType(str, Enum): class FieldType(str, Enum):
STRING = "str" STRING = "str"
@@ -176,7 +177,9 @@ CRITICAL RULES:
Examples: Examples:
- BAD: structured_output: str = "A JSON containing patient data..." - BAD: structured_output: str = "A JSON containing patient data..."
- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)""" - GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)
"""
) )
signature_name: str = dspy.OutputField( signature_name: str = dspy.OutputField(
desc="Suggested class name for the signature (PascalCase)" desc="Suggested class name for the signature (PascalCase)"
@@ -186,11 +189,14 @@ Examples:
class SignatureGenerator(dspy.Module): class SignatureGenerator(dspy.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.generator = dspy.ChainOfThought(SignatureGeneration) self.generator = dspy.Predict(SignatureGeneration)
def forward(self, prompt: str): def forward(self, prompt: str):
"""Generate DSPy signature and return raw prediction attributes""" """Generate DSPy signature and return raw prediction attributes"""
start_time = time()
result = self.generator(prompt=prompt) result = self.generator(prompt=prompt)
end_time = time()
print(f"Signature generation took {end_time - start_time:.2f} seconds in inference.")
return dspy.Prediction( return dspy.Prediction(
signature_name=result.signature_name, signature_name=result.signature_name,

30
agent/utils.py Normal file
View File

@@ -0,0 +1,30 @@
from typing import Dict, Any
import re
import os
from dotenv import load_dotenv
load_dotenv()
def to_snake_case(name: str) -> str:
"""Convert PascalCase or camelCase string to snake_case."""
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def save_signature_to_file(
result: Dict[str, Any], output_path: str | None = None
) -> str:
"""Save generated signature code to a Python file"""
if not output_path:
signature_name = result.get("signature_name", "generated_signature")
output_path = f"{to_snake_case(signature_name)}.py"
if result.get("code"):
with open(output_path, "w") as f:
f.write(result["code"])
return output_path
else:
raise ValueError("No code to save")
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")

View File

@@ -1,6 +1,6 @@
{ {
"lm": "gpt-5-2025-08-07", "lm": "openrouter/anthropic/claude-haiku-4.5",
"refine_lm": "gpt-5-2025-08-07", "refine_lm": "openai/gpt-4o",
"max_tokens": 16000, "max_tokens": 16000,
"temperature": 1.0, "temperature": 1.0,
"max_attempts_to_refine": 5 "max_attempts_to_refine": 5

86
main.py
View File

@@ -1,9 +1,95 @@
from agent import PromptToSignatureAgent, PromptToSignatureConfig from agent import PromptToSignatureAgent, PromptToSignatureConfig
agent = PromptToSignatureAgent(PromptToSignatureConfig()) agent = PromptToSignatureAgent(PromptToSignatureConfig())
PROMPT = """**INPUT FORMAT**: Raw clinical notes in free-text format, typically 200-2000 words, containing unstructured medical documentation from patient encounters including history, examination findings, diagnoses, treatment plans, and follow-up instructions.
**TASK DESCRIPTION**: You are a medical information extraction system. Analyze the provided clinical notes and perform comprehensive structured data extraction along with risk stratification. Your task involves multiple sub-tasks executed simultaneously.
**EXTRACTION REQUIREMENTS**:
- **Patient Demographics**: Extract full legal name, age (in years), biological sex/gender, date of birth (format: YYYY-MM-DD), medical record number (MRN) if present
- **Primary Diagnosis**: Identify the main diagnosis with corresponding ICD-10 code, include diagnostic certainty level (confirmed/suspected/rule-out)
- **Secondary Diagnoses**: List all comorbidities and additional conditions mentioned, each with ICD-10 codes where applicable
- **Medications**: Extract complete medication list including generic and brand names, dosages with units (mg, mcg, mL), frequency (BID, TID, QID, PRN), route of administration (PO, IV, IM, topical), and duration if specified
- **Allergies**: Document all allergies with allergen name, reaction type (rash, anaphylaxis, nausea, etc.), and severity classification (mild/moderate/severe/life-threatening)
- **Vital Signs**: Extract most recent measurements - blood pressure (systolic/diastolic in mmHg), heart rate (bpm), temperature (°F or °C with unit), respiratory rate (breaths/min), oxygen saturation (%), and pain score (0-10 scale)
- **Laboratory Results**: Identify all lab values mentioned with test name, numerical result, unit of measurement, reference range, and flag if abnormal (high/low/critical)
- **Appointments**: Extract scheduled follow-up dates, appointment types (follow-up, specialist referral, procedure), and provider names
**CLASSIFICATION REQUIREMENTS**:
- **Urgency Level**: Classify the case into one of four categories:
- ROUTINE: Stable patient, chronic condition management, no acute concerns
- URGENT: Requires attention within 24-48 hours, acute but not life-threatening condition
- EMERGENCY: Immediate intervention required, potentially life-threatening presentation
- CRITICAL: Life-threatening emergency requiring immediate intervention (ICU-level care)
**OUTPUT FORMAT**: Return the following schema:
{
"patient_demographics": {
"name": "string",
"age": "integer",
"gender": "string",
"dob": "YYYY-MM-DD",
"mrn": "string or null"
},
"primary_diagnosis": {
"condition": "string",
"icd10_code": "string",
"certainty": "confirmed|suspected|rule-out"
},
"secondary_diagnoses": [
{"condition": "string", "icd10_code": "string"}
],
"medications": [
{
"name": "string",
"dosage": "string",
"frequency": "string",
"route": "string",
"duration": "string or null"
}
],
"allergies": [
{
"allergen": "string",
"reaction": "string",
"severity": "mild|moderate|severe|life-threatening"
}
],
"vital_signs": {
"blood_pressure": "string (systolic/diastolic)",
"heart_rate": "integer",
"temperature": "float with unit",
"respiratory_rate": "integer",
"oxygen_saturation": "integer",
"pain_score": "integer (0-10)"
},
"lab_results": [
{
"test_name": "string",
"value": "float",
"unit": "string",
"reference_range": "string",
"flag": "normal|high|low|critical"
}
],
"appointments": [
{
"date": "YYYY-MM-DD",
"type": "string",
"provider": "string"
}
],
"urgency_classification": {
"level": "ROUTINE|URGENT|EMERGENCY|CRITICAL",
"reasoning": "string (brief explanation for classification)"
}
}
"""
def main(): def main():
agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True) agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)
#result = agent(PROMPT)
#code = agent.generate_code(result)
#print(code)
if __name__ == "__main__": if __name__ == "__main__":
main() main()