diff --git a/agent.json b/agent.json index 29a9a2e..8464c61 100644 --- a/agent.json +++ b/agent.json @@ -1,5 +1,5 @@ { - "signature_generator.generator.predict": { + "signature_generator.generator": { "traces": [], "train": [], "demos": [], @@ -10,17 +10,13 @@ "prefix": "Prompt:", "description": "Natural language description of the desired functionality." }, - { - "prefix": "Reasoning: Let's think step by step in order to", - "description": "${reasoning}" - }, { "prefix": "Task Description:", "description": "Clear description of what the signature aims to accomplish." }, { "prefix": "Signature Fields:", - "description": "List of input and output fields for the signature.\n\nCRITICAL RULES:\n1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema\n2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs\n3. For Pydantic models: define ALL nested fields properly in the schema with correct types\n4. Simple outputs (single values) can use basic types like str, int, bool\n5. Use Literal types for enumerated values (e.g., severity levels, status codes)\n\nExamples:\n- BAD: structured_output: str = \"A JSON containing patient data...\"\n- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)" + "description": "List of input and output fields for the signature.\n\n CRITICAL RULES:\n 1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema\n 2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs\n 3. For Pydantic models: define ALL nested fields properly in the schema with correct types\n 4. Simple outputs (single values) can use basic types like str, int, bool\n 5. Use Literal types for enumerated values (e.g., severity levels, status codes)\n\n Examples:\n - BAD: structured_output: str = \"A JSON containing patient data...\"\n - GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)\n \n " }, { "prefix": "Signature Name:", @@ -29,7 +25,7 @@ ] }, "lm": { - "model": "gpt-5-2025-08-07", + "model": "openai/gpt-4o", "model_type": "chat", "cache": true, "num_retries": 3, @@ -37,7 +33,7 @@ "launch_kwargs": {}, "train_kwargs": {}, "temperature": 1.0, - "max_completion_tokens": 16000 + "max_tokens": 16000 } }, "metadata": { diff --git a/agent/index.py b/agent/index.py index 1578605..fc2b1a0 100644 --- a/agent/index.py +++ b/agent/index.py @@ -1,10 +1,11 @@ import dspy +from .utils import OPENROUTER_API_BASE, OPENROUTER_API_KEY from .modules import signature_generator from modaic import PrecompiledAgent, PrecompiledConfig class PromptToSignatureConfig(PrecompiledConfig): - lm: str = "gpt-5-2025-08-07" - refine_lm: str = "gpt-5-2025-08-07" + lm: str = "openrouter/anthropic/claude-haiku-4.5" + refine_lm: str = "openai/gpt-4o" max_tokens: int = 16000 temperature: float = 1.0 max_attempts_to_refine: int = 5 @@ -27,6 +28,8 @@ class PromptToSignatureAgent(PrecompiledAgent): model=config.lm, max_tokens=config.max_tokens, temperature=config.temperature, + api_base=OPENROUTER_API_BASE, + api_key=OPENROUTER_API_KEY ) refine_lm = dspy.LM( model=config.refine_lm, diff --git a/agent/modules.py b/agent/modules.py index 1fae1cb..07b0f6d 100644 --- a/agent/modules.py +++ b/agent/modules.py @@ -2,6 +2,7 @@ import dspy from enum import Enum from typing import Any, Dict, List, Optional, Literal, Union, get_origin from pydantic import BaseModel, Field, ValidationError +from time import time class FieldType(str, Enum): STRING = "str" @@ -167,16 +168,18 @@ class SignatureGeneration(dspy.Signature): signature_fields: list[GeneratedField] = dspy.OutputField( desc="""List of input and output fields for the signature. -CRITICAL RULES: -1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema -2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs -3. For Pydantic models: define ALL nested fields properly in the schema with correct types -4. Simple outputs (single values) can use basic types like str, int, bool -5. Use Literal types for enumerated values (e.g., severity levels, status codes) + CRITICAL RULES: + 1. If the prompt describes a structured output with multiple nested fields (e.g., medical records, user profiles, complex forms), you MUST use type='pydantic' with a full PydanticModelSchema + 2. NEVER use type='str' with a description like 'JSON string containing...' for complex outputs + 3. For Pydantic models: define ALL nested fields properly in the schema with correct types + 4. Simple outputs (single values) can use basic types like str, int, bool + 5. Use Literal types for enumerated values (e.g., severity levels, status codes) -Examples: -- BAD: structured_output: str = "A JSON containing patient data..." -- GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields)""" + Examples: + - BAD: structured_output: str = "A JSON containing patient data..." + - GOOD: medical_record: MedicalRecord (with full PydanticModelSchema defining all nested fields) + + """ ) signature_name: str = dspy.OutputField( desc="Suggested class name for the signature (PascalCase)" @@ -186,11 +189,14 @@ Examples: class SignatureGenerator(dspy.Module): def __init__(self): super().__init__() - self.generator = dspy.ChainOfThought(SignatureGeneration) + self.generator = dspy.Predict(SignatureGeneration) def forward(self, prompt: str): """Generate DSPy signature and return raw prediction attributes""" + start_time = time() result = self.generator(prompt=prompt) + end_time = time() + print(f"Signature generation took {end_time - start_time:.2f} seconds in inference.") return dspy.Prediction( signature_name=result.signature_name, diff --git a/agent/utils.py b/agent/utils.py new file mode 100644 index 0000000..7c8e2a6 --- /dev/null +++ b/agent/utils.py @@ -0,0 +1,30 @@ +from typing import Dict, Any +import re +import os +from dotenv import load_dotenv + +load_dotenv() + +def to_snake_case(name: str) -> str: + """Convert PascalCase or camelCase string to snake_case.""" + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def save_signature_to_file( + result: Dict[str, Any], output_path: str | None = None +) -> str: + """Save generated signature code to a Python file""" + if not output_path: + signature_name = result.get("signature_name", "generated_signature") + output_path = f"{to_snake_case(signature_name)}.py" + + if result.get("code"): + with open(output_path, "w") as f: + f.write(result["code"]) + return output_path + else: + raise ValueError("No code to save") + +OPENROUTER_API_BASE = "https://openrouter.ai/api/v1" +OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") diff --git a/config.json b/config.json index b9f060f..ca5585d 100644 --- a/config.json +++ b/config.json @@ -1,6 +1,6 @@ { - "lm": "gpt-5-2025-08-07", - "refine_lm": "gpt-5-2025-08-07", + "lm": "openrouter/anthropic/claude-haiku-4.5", + "refine_lm": "openai/gpt-4o", "max_tokens": 16000, "temperature": 1.0, "max_attempts_to_refine": 5 diff --git a/main.py b/main.py index 457f062..f1ab713 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,95 @@ from agent import PromptToSignatureAgent, PromptToSignatureConfig agent = PromptToSignatureAgent(PromptToSignatureConfig()) +PROMPT = """**INPUT FORMAT**: Raw clinical notes in free-text format, typically 200-2000 words, containing unstructured medical documentation from patient encounters including history, examination findings, diagnoses, treatment plans, and follow-up instructions. +**TASK DESCRIPTION**: You are a medical information extraction system. Analyze the provided clinical notes and perform comprehensive structured data extraction along with risk stratification. Your task involves multiple sub-tasks executed simultaneously. + +**EXTRACTION REQUIREMENTS**: +- **Patient Demographics**: Extract full legal name, age (in years), biological sex/gender, date of birth (format: YYYY-MM-DD), medical record number (MRN) if present +- **Primary Diagnosis**: Identify the main diagnosis with corresponding ICD-10 code, include diagnostic certainty level (confirmed/suspected/rule-out) +- **Secondary Diagnoses**: List all comorbidities and additional conditions mentioned, each with ICD-10 codes where applicable +- **Medications**: Extract complete medication list including generic and brand names, dosages with units (mg, mcg, mL), frequency (BID, TID, QID, PRN), route of administration (PO, IV, IM, topical), and duration if specified +- **Allergies**: Document all allergies with allergen name, reaction type (rash, anaphylaxis, nausea, etc.), and severity classification (mild/moderate/severe/life-threatening) +- **Vital Signs**: Extract most recent measurements - blood pressure (systolic/diastolic in mmHg), heart rate (bpm), temperature (°F or °C with unit), respiratory rate (breaths/min), oxygen saturation (%), and pain score (0-10 scale) +- **Laboratory Results**: Identify all lab values mentioned with test name, numerical result, unit of measurement, reference range, and flag if abnormal (high/low/critical) +- **Appointments**: Extract scheduled follow-up dates, appointment types (follow-up, specialist referral, procedure), and provider names + +**CLASSIFICATION REQUIREMENTS**: +- **Urgency Level**: Classify the case into one of four categories: + - ROUTINE: Stable patient, chronic condition management, no acute concerns + - URGENT: Requires attention within 24-48 hours, acute but not life-threatening condition + - EMERGENCY: Immediate intervention required, potentially life-threatening presentation + - CRITICAL: Life-threatening emergency requiring immediate intervention (ICU-level care) + +**OUTPUT FORMAT**: Return the following schema: +{ + "patient_demographics": { + "name": "string", + "age": "integer", + "gender": "string", + "dob": "YYYY-MM-DD", + "mrn": "string or null" + }, + "primary_diagnosis": { + "condition": "string", + "icd10_code": "string", + "certainty": "confirmed|suspected|rule-out" + }, + "secondary_diagnoses": [ + {"condition": "string", "icd10_code": "string"} + ], + "medications": [ + { + "name": "string", + "dosage": "string", + "frequency": "string", + "route": "string", + "duration": "string or null" + } + ], + "allergies": [ + { + "allergen": "string", + "reaction": "string", + "severity": "mild|moderate|severe|life-threatening" + } + ], + "vital_signs": { + "blood_pressure": "string (systolic/diastolic)", + "heart_rate": "integer", + "temperature": "float with unit", + "respiratory_rate": "integer", + "oxygen_saturation": "integer", + "pain_score": "integer (0-10)" + }, + "lab_results": [ + { + "test_name": "string", + "value": "float", + "unit": "string", + "reference_range": "string", + "flag": "normal|high|low|critical" + } + ], + "appointments": [ + { + "date": "YYYY-MM-DD", + "type": "string", + "provider": "string" + } + ], + "urgency_classification": { + "level": "ROUTINE|URGENT|EMERGENCY|CRITICAL", + "reasoning": "string (brief explanation for classification)" + } +} +""" def main(): agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True) + #result = agent(PROMPT) + #code = agent.generate_code(result) + #print(code) if __name__ == "__main__": main()