From 3d9154d021479a331e9d92ee376f879f289e8089 Mon Sep 17 00:00:00 2001 From: Farouk Adeleke Date: Sun, 30 Nov 2025 11:39:10 -0500 Subject: [PATCH] MIPROv2 Optimized Clinical Impact Judge --- README.md | 73 ++++++++++++++++- agent.json | 67 +++++++++++++++ auto_classes.json | 4 + config.json | 10 +++ pyproject.toml | 12 +++ src/llm_judge/__init__.py | 10 +++ src/llm_judge/cli/__init__.py | 1 + src/llm_judge/cli/run_mipro.py | 105 ++++++++++++++++++++++++ src/llm_judge/data.py | 51 ++++++++++++ src/llm_judge/eval.py | 62 ++++++++++++++ src/llm_judge/metrics.py | 108 +++++++++++++++++++++++++ src/llm_judge/optimizers/__init__.py | 3 + src/llm_judge/optimizers/factory.py | 14 ++++ src/llm_judge/optimizers/gepa.py | 8 ++ src/llm_judge/optimizers/mipro.py | 7 ++ src/llm_judge/providers/__init__.py | 3 + src/llm_judge/providers/bedrock.py | 28 +++++++ src/llm_judge/providers/factory.py | 25 ++++++ src/llm_judge/providers/gemini.py | 24 ++++++ src/llm_judge/providers/ollama_chat.py | 28 +++++++ src/llm_judge/providers/openrouter.py | 38 +++++++++ src/llm_judge/signatures.py | 58 +++++++++++++ 22 files changed, 738 insertions(+), 1 deletion(-) create mode 100644 agent.json create mode 100644 auto_classes.json create mode 100644 config.json create mode 100644 pyproject.toml create mode 100644 src/llm_judge/__init__.py create mode 100644 src/llm_judge/cli/__init__.py create mode 100644 src/llm_judge/cli/run_mipro.py create mode 100644 src/llm_judge/data.py create mode 100644 src/llm_judge/eval.py create mode 100644 src/llm_judge/metrics.py create mode 100644 src/llm_judge/optimizers/__init__.py create mode 100644 src/llm_judge/optimizers/factory.py create mode 100644 src/llm_judge/optimizers/gepa.py create mode 100644 src/llm_judge/optimizers/mipro.py create mode 100644 src/llm_judge/providers/__init__.py create mode 100644 src/llm_judge/providers/bedrock.py create mode 100644 src/llm_judge/providers/factory.py create mode 100644 src/llm_judge/providers/gemini.py create mode 100644 src/llm_judge/providers/ollama_chat.py create mode 100644 src/llm_judge/providers/openrouter.py create mode 100644 src/llm_judge/signatures.py diff --git a/README.md b/README.md index c24b625..a2189f9 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,73 @@ -# clinical-impact-judge-mipro +# WER is Unaware: Assessing How ASR Errors Distort Clinical Understanding in Patient-Facing Dialogue +This repository hosts the code, models, and datasets accompanying the paper. The work investigates how Automatic Speech Recognition (ASR) errors **distort clinical meaning in patient-facing dialogue** — and shows that traditional metrics like Word Error Rate (WER) fail to capture real clinical risk. The project includes scripts for aligning ground-truth utterances to ASR-generated utterances using an **LLM-based semantic aligner**, and optimizing an **LLM-as-a-Judge for clinical impact assessment** using GEPA through DSPy. + +## 📝 Abstract +![WER is Unaware Overview](overview.png) + +As Automatic Speech Recognition (ASR) is increasingly deployed in clinical dialogue, standard evaluations still rely heavily on Word Error Rate (WER). This paper challenges that standard, investigating whether WER or other common metrics correlate with the clinical impact of transcription errors. We establish a gold-standard benchmark by having expert clinicians compare ground-truth utterances to their ASR-generated counterparts, labeling the clinical impact of any discrepancies found in two distinct doctor-patient dialogue datasets. Our analysis reveals that WER and a comprehensive suite of existing metrics correlate poorly with the clinician-assigned risk labels (No, Minimal, or Significant Impact). To bridge this evaluation gap, we introduce an LLM-as-a-Judge, programmatically optimized using GEPA to replicate expert clinical assessment. The optimized judge (Gemini-2.5-Pro) achieves human-comparable performance, obtaining 90% accuracy and a strong Cohen's Îș of 0.816. This work provides a validated, automated framework for moving ASR evaluation beyond simple textual fidelity to a necessary, scalable assessment of safety in clinical dialogue. + +## 🔍 Overview + +We introduce (available here): +- Clinician-annotated clinical-impact dataset: `llm_judge/dataset/primock_data_final_outcomes.csv` +- Semantic LLM-based aligner: `alignment/aligner/` (see `alignment/README.md` for usage) +- LLM-as-a-Judge optimized with GEPA/MIPRO: `llm_judge/` (artifacts in `llm_judge/results/`) +- Evaluations of ASR metrics (code under `alignment/scripts/` and `alignment/results/`) + +## đŸ› ïž Environment Setup + +- Install Python 3.10+ and `uv` (recommended): https://github.com/astral-sh/uv +- Install dependencies: `uv sync` +- Environment: + - OpenRouter (default for LLM calls): `OPENROUTER_API_KEY` (required), `OPENROUTER_MODEL` optional + - Gemini (optional): `GCP_PROJECT_ID`, `GCP_LOCATION` + - Bedrock (optional): `AWS_REGION` + +- Example: run aligner evaluation + ```bash + uv run python alignment/scripts/run_evaluation.py --case-id sample --asr-system demo + ``` + +- Example: run judge (GEPA) + ```bash + uv run python -m llm_judge.cli.run_gepa \ + --data-path llm_judge/dataset/primock_data_final_outcomes.csv \ + --provider openrouter \ + --task-model meta-llama/llama-3.3-70b-instruct \ + --reflection-model anthropic/claude-4-sonnet \ + --output llm_judge/results/clinical_judge_gepa.json + ``` + +## 📁 Folder Structure +- `alignment/` — semantic alignment toolkit (aligner code, scripts, sample data, sample results). +- `llm_judge/` — clinical impact judge (signatures, metrics, providers, optimizers, CLI, bundled dataset, saved judges). + +### Important Files +- `alignment/data/` — example ASR transcripts and ground-truth alignments. +- `alignment/results/` — sample alignment evaluations. +- `llm_judge/dataset/` — clinical-impact dataset. +- `llm_judge/results/` — optimized judges (GEPA, MIPROv2). + +## 📩 Coming Soon + +- Additional dataset metadata and documentation +- Evaluations of 20+ ASR metrics, showing their poor correlation with clinical safety + +## 📄 Paper + +Preprint available on arXiv: https://arxiv.org/abs/2511.16544 + +## 📚 Citation + +```bibtex +@misc{ellis2025werunawareassessingasr, + title={WER is Unaware: Assessing How ASR Errors Distort Clinical Understanding in Patient Facing Dialogue}, + author={Zachary Ellis and Jared Joselowitz and Yash Deo and Yajie He and Anna Kalygina and Aisling Higham and Mana Rahimzadeh and Yan Jia and Ibrahim Habli and Ernest Lim}, + year={2025}, + eprint={2511.16544}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2511.16544}, +} +``` diff --git a/agent.json b/agent.json new file mode 100644 index 0000000..55bcb81 --- /dev/null +++ b/agent.json @@ -0,0 +1,67 @@ +{ + "assess.predict": { + "traces": [], + "train": [], + "demos": [ + { + "ground_truth_conversation": "(42) Doctor: I really hope you get better soon.\n(42) Patient: Thanks a lot. Thanks so much, for your time.\n(43) Doctor: OK. Bye bye.\n(43) Patient: OK. Bye.", + "transcription_conversation": "(42) Doctor: I really hope you get better soon.\n(42) Patient: Thanks a lot. Thanks so much, for your time.\n(43) Doctor: OK. Bye bye.\n(43) Patient: ", + "clinical_impact": "0" + }, + { + "ground_truth_conversation": "(0) Doctor: Hello? Hello. Um, before I go any further, can I confirm your name and your date of birth?\n(0) Patient: Uh, yeah. Uh my name is April, and I'm fifty.\n(1) Doctor: You're fifty, OK. April, how can I help you this afternoon?\n(1) Patient: Well I've just been having this like, cough, for quite a few days. And my nose is running. Um and it's just been super annoying, and it's not going away.", + "transcription_conversation": "(0) Doctor: Hello? Hello. Um, before I go any further, can I confirm your name and your date of birth?\n(0) Patient: Uh, yeah. Uh my name is April, and I'm fifty.\n(1) Doctor: You're fifty, OK. April, how can I help you this afternoon?\n(1) Patient: well i've just been having this like cough for quite a few days and my nose is running and it's just been super annoying and it's not going away", + "clinical_impact": "0" + }, + { + "ground_truth_conversation": "(44) Doctor: Any pain, or strong smell, or having to go more often than normal?\n(44) Patient: So, I, I, I don't really, I have noticed that I don't really go up that high before, so, yeah, I do drink a lot though.\n(45) Doctor: Right, OK. And any , any weight loss or blood in the stool, or urine?\n(45) Patient: Yes. I have . As I said, I have some, I guess, discharge and pain. But, then I there's isn't a lot of discharge.", + "transcription_conversation": "(44) Doctor: Any pain, or strong smell, or having to go more often than normal?\n(44) Patient: So, I, I, I don't really, I have noticed that I don't really go up that high before, so, yeah, I do drink a lot though.\n(45) Doctor: Right, OK. And any , any weight loss or blood in the stool, or urine?\n(45) Patient: yes i have i was said i have some i can get some discharge and bleeding but then i do that there is no discharge", + "clinical_impact": "2" + }, + { + "ground_truth_conversation": "(12) Doctor: Bit of tightness, OK, OK. So, I, unfortunately, I don't have much of your past medical history. Do you have any significant past medical or surgical history you'd like to share?\n(12) Patient: Uh, no, not really. Uh, a bit of eczema and I'm allergic to penicillin.\n(13) Doctor: OK. Are you on any regular medication at all? OK. Righty-ho. Are you having any rash on your body at the moment?\n(13) Patient: Yes there's a bit of blotching .", + "transcription_conversation": "(12) Doctor: Bit of tightness, OK, OK. So, I, unfortunately, I don't have much of your past medical history. Do you have any significant past medical or surgical history you'd like to share?\n(12) Patient: Uh, no, not really. Uh, a bit of eczema and I'm allergic to penicillin.\n(13) Doctor: OK. Are you on any regular medication at all? OK. Righty-ho. Are you having any rash on your body at the moment?\n(13) Patient: yes there's a bit of blocking down here", + "clinical_impact": "2" + } + ], + "signature": { + "instructions": "As a medical expert, your task is to evaluate the clinical impact of errors in a transcribed medical conversation. You will compare a \"Ground Truth Conversation\" with a \"Transcription Conversation\" to identify and assess any discrepancies.\n\nYour assessment must follow two steps:\n1. **Reasoning:** Provide a brief clinical justification for your assessment. Identify the specific error(s) (e.g., word substitution, omission) and explain why they do or do not affect the medical meaning or potential patient care.\n2. **Clinical Impact:** After your reasoning, assign a single numerical score based on the severity of the error(s).\n\nUse the following strict classification guide for the 'Clinical Impact' score:\n\n* **0 = No impact:** Errors are purely cosmetic and do not alter medical meaning. This includes differences in punctuation, capitalization, or filler words (e.g., 'um', 'uh').\n* **1 = Minimal impact:** Some information is missing or changed, but it is NOT critical to immediate diagnosis or treatment decisions. For example, obscuring details about a resolved past issue or losing minor, non-essential context.\n* **2 = Significant impact:** Critical information is missing or incorrect in a way that COULD affect diagnosis, treatment decisions, or patient safety. Be especially vigilant for errors related to:\n * Medication names (e.g., \"Ibuprofen\" becomes \"upper brooklyn\")\n * Key symptoms (e.g., \"pain\" is changed to \"bleeding\")\n * Dosages, allergies, or critical patient history details.\n\nFirst, write your reasoning, then provide the final number for the clinical impact.", + "fields": [ + { + "prefix": "Ground Truth Conversation:", + "description": "${ground_truth_conversation}" + }, + { + "prefix": "Transcription Conversation:", + "description": "${transcription_conversation}" + }, + { + "prefix": "Reasoning:", + "description": "Brief clinical justification for the assessment." + }, + { + "prefix": "Clinical Impact:", + "description": "Clinical impact class (return ONLY the number):\n 0 = No impact: cosmetic differences only (punctuation, capitalization, filler words)\n 1 = Minimal impact: some information missing/changed but NOT critical to diagnosis or treatment decisions \n 2 = Significant impact: missing/incorrect information that COULD affect diagnosis, treatment, or patient safety\n Return ONLY: 0, 1, or 2" + } + ] + }, + "lm": { + "model": "openrouter/google/gemini-2.5-pro", + "model_type": "chat", + "cache": true, + "num_retries": 3, + "finetuning_model": null, + "launch_kwargs": {}, + "train_kwargs": {}, + "temperature": 0.1, + "max_tokens": 8000 + } + }, + "metadata": { + "dependency_versions": { + "python": "3.13", + "dspy": "3.0.4", + "cloudpickle": "3.1" + } + } +} \ No newline at end of file diff --git a/auto_classes.json b/auto_classes.json new file mode 100644 index 0000000..9764aba --- /dev/null +++ b/auto_classes.json @@ -0,0 +1,4 @@ +{ + "AutoConfig": "src.llm_judge.signatures.ClinicalImpactJudgeConfig", + "AutoAgent": "src.llm_judge.signatures.ClinicalImpactJudge" +} \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..b15a3dc --- /dev/null +++ b/config.json @@ -0,0 +1,10 @@ +{ + "task_model": "openrouter/google/gemini-2.5-pro", + "reflection_model": null, + "max_tokens": 8000, + "temperature": 0.1, + "test_size": 50, + "val_size": 30, + "seed": 42, + "auto": "medium" +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..74b5d30 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "clinical-impact-judge-mipro" +version = "0.1.0" +description = "LLM transcript alignment and evaluation toolkit" +readme = "README.md" +requires-python = ">=3.11" +dependencies = ["openai>=1.35.0", "num2words>=0.5.13", "python-dotenv>=1.2.1", "dspy>=3.0.4", "modaic>=0.4.1", "scikit-learn>=1.7.2", "vertexai>=1.71.1"] + +[project.optional-dependencies] +plot = ["matplotlib>=3.8"] +dev = ["pytest>=8.0"] + diff --git a/src/llm_judge/__init__.py b/src/llm_judge/__init__.py new file mode 100644 index 0000000..44398e5 --- /dev/null +++ b/src/llm_judge/__init__.py @@ -0,0 +1,10 @@ +"""Lightweight helpers for running DSPy judges with GEPA or MIPRO.""" + +__all__ = [ + "signatures", + "metrics", + "data", + "models", + "optimizers", + "eval", +] diff --git a/src/llm_judge/cli/__init__.py b/src/llm_judge/cli/__init__.py new file mode 100644 index 0000000..742a1a7 --- /dev/null +++ b/src/llm_judge/cli/__init__.py @@ -0,0 +1 @@ +# CLI entrypoints live here for `python -m llm_judge.cli.run_*`. diff --git a/src/llm_judge/cli/run_mipro.py b/src/llm_judge/cli/run_mipro.py new file mode 100644 index 0000000..a0d089d --- /dev/null +++ b/src/llm_judge/cli/run_mipro.py @@ -0,0 +1,105 @@ +import argparse +import os + +from dotenv import load_dotenv + +from src.llm_judge import data as data_utils +from src.llm_judge import eval as eval_utils +from src.llm_judge import metrics, signatures +from src.llm_judge.optimizers import get_optimizer +from src.llm_judge.providers import setup_models + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run MIPROv2 optimization for clinical impact judge." + ) + parser.add_argument( + "--data-path", type=str, default=os.getenv("DATA_PATH"), help="CSV file path." + ) + parser.add_argument( + "--provider", + type=str, + default="openrouter", + choices=["gemini", "bedrock", "openrouter"], + ) + parser.add_argument("--task-model", type=str, default="google/gemini-2.5-pro") + parser.add_argument("--output", type=str, default="clinical_judge_mipro.json") + parser.add_argument("--test-size", type=int, default=50) + parser.add_argument("--val-size", type=int, default=30) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--auto", + type=str, + default="medium", + help="MIPRO auto level: light|medium|heavy", + ) + return parser.parse_args() + + +def main(): + load_dotenv() + args = parse_args() + + if not args.data_path: + raise SystemExit("Please provide --data-path or set DATA_PATH.") + + print("=" * 80) + print("DSPy Clinical Impact Judge - MIPROv2") + print("=" * 80) + + # Data + df = data_utils.load_dataset(args.data_path) + trainset, valset, testset = data_utils.build_splits( + df, test_size=args.test_size, val_size=args.val_size, random_state=args.seed + ) + print(f"Train: {len(trainset)} | Val: {len(valset)} | Test: {len(testset)}") + + # Models + task_lm, _ = setup_models(args.provider, task_model=args.task_model) + print(f"✓ Connected to provider={args.provider} task_model={args.task_model}") + + # Optimizer + optimizer = get_optimizer( + "mipro", + metric=metrics.simple_metric, + auto=args.auto, + seed=args.seed, + ) + model_base = ( + f"{args.provider}/" + if args.provider == "openrouter" + or args.provider == "bedrock" + or args.provider == "ollama_chat" + else "vertex_ai/" + ) + + config = signatures.ClinicalImpactJudgeConfig( + task_model=model_base + args.task_model, + test_size=args.test_size, + val_size=args.val_size, + seed=args.seed, + auto=args.auto, + max_tokens=8000, + ) + judge = signatures.ClinicalImpactJudge(config) + + optimized_judge = optimizer.compile( + judge, + trainset=trainset, + requires_permission_to_run=False, + ) + optimized_judge.save(args.output) + optimized_judge.push_to_hub( + "jaredjoss123/clinical-impact-judge-mipro", + with_code=True, + commit_message="MIPROv2 Optimized Clinical Impact Judge", + ) + print(f"✓ Saved optimized judge to {args.output}") + + # Evaluate on test + eval_utils.evaluate_judge(optimized_judge, testset, name="MIPROv2 Optimized") + + +if __name__ == "__main__": + main() diff --git a/src/llm_judge/data.py b/src/llm_judge/data.py new file mode 100644 index 0000000..0db3cb3 --- /dev/null +++ b/src/llm_judge/data.py @@ -0,0 +1,51 @@ +from typing import List, Tuple + +import dspy +import pandas as pd +from sklearn.model_selection import train_test_split + + +def load_dataset(csv_path: str) -> pd.DataFrame: + """Load dataset and drop rows without final_outcome.""" + df = pd.read_csv(csv_path) + return df.dropna(subset=["final_outcome"]) + + +def split_dataset( + df: pd.DataFrame, + test_size: int = 50, + val_size: int = 30, + random_state: int = 42, +) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Split into train/val/test with fixed sizes and stratification.""" + y = df["final_outcome"] + train_val_df, test_df = train_test_split( + df, test_size=test_size, stratify=y, random_state=random_state + ) + train_df, val_df = train_test_split( + train_val_df, + test_size=val_size, + stratify=train_val_df["final_outcome"], + random_state=random_state, + ) + return train_df, val_df, test_df + + +def create_dspy_example(row) -> dspy.Example: + """Convert dataframe row to DSPy Example.""" + return dspy.Example( + ground_truth_conversation=str(row["fer_gt_context"]), + transcription_conversation=str(row["fer_hyp_context"]), + clinical_impact=str(int(row["final_outcome"])), + ).with_inputs("ground_truth_conversation", "transcription_conversation") + + +def build_splits( + df: pd.DataFrame, test_size: int = 50, val_size: int = 30, random_state: int = 42 +) -> Tuple[List[dspy.Example], List[dspy.Example], List[dspy.Example]]: + """Create DSPy train/val/test example lists from dataframe splits.""" + train_df, val_df, test_df = split_dataset(df, test_size, val_size, random_state) + trainset = [create_dspy_example(row) for _, row in train_df.iterrows()] + valset = [create_dspy_example(row) for _, row in val_df.iterrows()] + testset = [create_dspy_example(row) for _, row in test_df.iterrows()] + return trainset, valset, testset diff --git a/src/llm_judge/eval.py b/src/llm_judge/eval.py new file mode 100644 index 0000000..538f42d --- /dev/null +++ b/src/llm_judge/eval.py @@ -0,0 +1,62 @@ +import pandas as pd +from sklearn.metrics import classification_report, cohen_kappa_score, confusion_matrix + +from .metrics import parse_label + + +def evaluate_judge(judge, testset, name="Judge"): + """Evaluate a judge on a testset and print metrics.""" + print("\n" + "=" * 80) + print(f"EVALUATING: {name}") + print("=" * 80) + + results = [] + for idx, example in enumerate(testset): + try: + prediction = judge( + ground_truth_conversation=example.ground_truth_conversation, + transcription_conversation=example.transcription_conversation, + ) + pred_label = parse_label(prediction.clinical_impact) + true_label = int(example.clinical_impact) + if pred_label is not None: + results.append({"true_label": true_label, "pred_label": pred_label}) + except Exception as exc: # pragma: no cover - runtime guardrail + print(f"Error on example {idx}: {exc}") + continue + + if not results: + return None, 0, 0 + + results_df = pd.DataFrame(results) + true_labels = results_df["true_label"].values + pred_labels = results_df["pred_label"].values + + accuracy = (true_labels == pred_labels).mean() * 100 + kappa = cohen_kappa_score(true_labels, pred_labels) + + print(f"\nAccuracy: {accuracy:.2f}%") + print(f"Cohen's Kappa: {kappa:.3f}") + print("\nClassification Report:") + print( + classification_report( + true_labels, + pred_labels, + target_names=["0 (No impact)", "1 (Minimal)", "2 (Significant)"], + zero_division=0, + ) + ) + print("\nConfusion Matrix:") + cm = confusion_matrix(true_labels, pred_labels) + print(" Predicted") + print(" 0 1 2") + for i, row_label in enumerate(["Actual 0", "Actual 1", "Actual 2"]): + print(f"{row_label:10s} {cm[i][0]:4d} {cm[i][1]:4d} {cm[i][2]:4d}") + + for class_label in [0, 1, 2]: + mask = true_labels == class_label + if mask.sum() > 0: + recall = (true_labels[mask] == pred_labels[mask]).mean() * 100 + print(f"Class {class_label} recall: {recall:.1f}%") + + return results_df, accuracy, kappa diff --git a/src/llm_judge/metrics.py b/src/llm_judge/metrics.py new file mode 100644 index 0000000..d25fdb3 --- /dev/null +++ b/src/llm_judge/metrics.py @@ -0,0 +1,108 @@ +import json +import re +from typing import Optional + +import dspy + +COST_MATRIX = [ + [1.2, 0.3, -1.0], + [0.3, 1.5, 0.5], + [-1.2, 0.4, 1.5], +] + + +def parse_label(label_str: str) -> Optional[int]: + try: + label_str = str(label_str).strip() + if label_str in {"0", "1", "2"}: + return int(label_str) + + json_match = re.search(r"\{.*\}", label_str, re.DOTALL) + if json_match: + obj = json.loads(json_match.group(0)) + val = obj.get("clinical_impact") + if val in [0, 1, 2] or str(val) in "012": + return int(val) + + num_match = re.search(r"\b([0-2])\b", label_str) + if num_match: + return int(num_match.group(1)) + except Exception: + return None + return None + + +def gepa_feedback_metric( + example, prediction, trace=None, pred_name=None, pred_trace=None +): + true_label = int(example.clinical_impact) + pred_label = parse_label(prediction.clinical_impact) + + if pred_label is None: + feedback = ( + f"PARSING ERROR: The model failed to output a valid class (0, 1, or 2). " + f"Raw output: '{prediction.clinical_impact}'. " + f"The model MUST return ONLY the number 0, 1, or 2 as specified in the output field description. " + f"Consider emphasizing in the instructions: output format must be strictly a single digit." + ) + return dspy.Prediction(score=-2.0, feedback=feedback) + + score = COST_MATRIX[true_label][pred_label] + + # Generate detailed feedback based on the prediction outcome + if pred_label == true_label: + class_names = {0: "No impact", 1: "Minimal impact", 2: "Significant impact"} + feedback = ( + f"CORRECT: Correctly identified as Class {true_label} ({class_names[true_label]}). " + f"The model's reasoning was appropriate for this classification. " + f"Continue using similar reasoning patterns for this type of case." + ) + else: + if true_label == 0 and pred_label > 0: + feedback = ( + f"OVER-CLASSIFICATION: Predicted Class {pred_label} but should be Class 0 (No impact). " + f"The transcription differences are cosmetic (punctuation, capitalization, filler words) " + f"and do NOT affect clinical meaning. The model should be MORE LENIENT with minor differences " + f"and focus ONLY on content that affects diagnosis or treatment decisions." + ) + elif true_label == 1 and pred_label == 0: + feedback = ( + f"UNDER-CLASSIFICATION: Predicted Class 0 but should be Class 1 (Minimal impact). " + f"While not critical to diagnosis/treatment, some clinically relevant information was " + f"missing or changed. The model should be MORE SENSITIVE to information changes, " + f"even if they don't directly affect critical decisions." + ) + elif true_label == 1 and pred_label == 2: + feedback = ( + f"OVER-CLASSIFICATION: Predicted Class 2 but should be Class 1 (Minimal impact). " + f"The information changes are not critical enough to affect diagnosis or patient safety. " + f"Reserve Class 2 ONLY for errors that COULD directly affect diagnosis, treatment, or safety. " + f"The model should distinguish between 'some information missing' vs 'critical information missing'." + ) + elif true_label == 2 and pred_label < 2: + feedback = ( + f"CRITICAL MISS: Predicted Class {pred_label} but should be Class 2 (Significant impact). " + f"This is a HIGH-PRIORITY error. The transcription contained missing/incorrect information " + f"that COULD affect diagnosis, treatment, or patient safety. The model MUST be MORE SENSITIVE " + f"to clinically critical information like symptoms, medications, measurements, or diagnoses. " + f"Look for: changed medical terms, missing symptoms, altered measurements, or omitted diagnoses." + ) + else: + feedback = ( + f"MAJOR ERROR: Predicted Class {pred_label} but should be Class {true_label}. " + f"This is a large classification error spanning 2 severity levels. " + f"The model needs to fundamentally reassess its criteria for clinical impact. " + f"Review the distinction between cosmetic changes, information changes, and critical errors." + ) + + feedback += f" [True: {true_label}, Predicted: {pred_label}]" + + return dspy.Prediction(score=score, feedback=feedback) + + +def simple_metric(example, prediction, trace=None): + true_label = int(example.clinical_impact) + pred_label = parse_label(prediction.clinical_impact) + if pred_label is None: + return -2.0 + return COST_MATRIX[true_label][pred_label] diff --git a/src/llm_judge/optimizers/__init__.py b/src/llm_judge/optimizers/__init__.py new file mode 100644 index 0000000..000c02e --- /dev/null +++ b/src/llm_judge/optimizers/__init__.py @@ -0,0 +1,3 @@ +from .factory import get_optimizer + +__all__ = ["get_optimizer"] diff --git a/src/llm_judge/optimizers/factory.py b/src/llm_judge/optimizers/factory.py new file mode 100644 index 0000000..fad41d7 --- /dev/null +++ b/src/llm_judge/optimizers/factory.py @@ -0,0 +1,14 @@ +from typing import Any, Optional + +from .gepa import build_gepa +from .mipro import build_mipro + + +def get_optimizer(name: str, metric, reflection_lm: Optional[Any] = None, **kwargs): + """Return a configured optimizer by name.""" + name = name.lower() + if name == "gepa": + return build_gepa(metric=metric, reflection_lm=reflection_lm, **kwargs) + if name in {"mipro", "miprov2"}: + return build_mipro(metric=metric, **kwargs) + raise ValueError(f"Unsupported optimizer: {name}") diff --git a/src/llm_judge/optimizers/gepa.py b/src/llm_judge/optimizers/gepa.py new file mode 100644 index 0000000..c83b22e --- /dev/null +++ b/src/llm_judge/optimizers/gepa.py @@ -0,0 +1,8 @@ +from dspy.teleprompt import GEPA + + +def build_gepa(metric, reflection_lm=None, **kwargs): + """Construct a GEPA optimizer.""" + if reflection_lm is not None: + kwargs["reflection_lm"] = reflection_lm + return GEPA(metric=metric, **kwargs) diff --git a/src/llm_judge/optimizers/mipro.py b/src/llm_judge/optimizers/mipro.py new file mode 100644 index 0000000..1ff4e7e --- /dev/null +++ b/src/llm_judge/optimizers/mipro.py @@ -0,0 +1,7 @@ +from dspy.teleprompt import MIPROv2 + + +def build_mipro(metric, **kwargs): + """Construct a MIPROv2 optimizer.""" + allowed = {k: v for k, v in kwargs.items() if k in {"auto", "seed"}} + return MIPROv2(metric=metric, **allowed) diff --git a/src/llm_judge/providers/__init__.py b/src/llm_judge/providers/__init__.py new file mode 100644 index 0000000..5369c0a --- /dev/null +++ b/src/llm_judge/providers/__init__.py @@ -0,0 +1,3 @@ +from .factory import setup_models + +__all__ = ["setup_models"] diff --git a/src/llm_judge/providers/bedrock.py b/src/llm_judge/providers/bedrock.py new file mode 100644 index 0000000..1ec7158 --- /dev/null +++ b/src/llm_judge/providers/bedrock.py @@ -0,0 +1,28 @@ +from typing import Optional, Tuple +import os +import dspy +from dotenv import load_dotenv + +load_dotenv() + + +def init_bedrock( + task_model: str, + reflection_model: Optional[str] = None, + region: str = "us-east-1", + max_tokens: int = 1000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use AWS Bedrock.""" + task_lm = dspy.LM( + f"bedrock/{task_model}", + region_name=(os.getenv("AWS_REGION", region)), + max_tokens=max_tokens, + ) + dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM( + f"bedrock/{reflection_model}", region_name=region, max_tokens=max_tokens + ) + return task_lm, reflection_lm diff --git a/src/llm_judge/providers/factory.py b/src/llm_judge/providers/factory.py new file mode 100644 index 0000000..1a5529a --- /dev/null +++ b/src/llm_judge/providers/factory.py @@ -0,0 +1,25 @@ +from typing import Optional + +from .bedrock import init_bedrock +from .gemini import init_gemini +from .openrouter import init_openrouter +from .ollama_chat import init_ollama_chat + + +def setup_models( + provider: str, + task_model: str, + reflection_model: Optional[str] = None, + **kwargs, +): + """Initialize task/reflection LMs and configure DSPy.""" + provider = provider.lower() + if provider == "gemini": + return init_gemini(task_model, reflection_model=reflection_model, **kwargs) + if provider == "bedrock": + return init_bedrock(task_model, reflection_model=reflection_model, **kwargs) + if provider == "openrouter": + return init_openrouter(task_model, reflection_model=reflection_model, **kwargs) + if provider == "ollam_chat": + return init_ollama_chat(task_model, reflection_model=reflection_model, **kwargs) + raise ValueError(f"Unsupported provider: {provider}") diff --git a/src/llm_judge/providers/gemini.py b/src/llm_judge/providers/gemini.py new file mode 100644 index 0000000..4dedd26 --- /dev/null +++ b/src/llm_judge/providers/gemini.py @@ -0,0 +1,24 @@ +import os +from typing import Optional, Tuple + +import dspy +import vertexai + + +def init_gemini( + task_model: str, + reflection_model: Optional[str] = None, + max_tokens: int = 8000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use Gemini via Vertex AI.""" + project = os.getenv("GCP_PROJECT_ID", "your-project-id") + location = os.getenv("GCP_LOCATION", "us-central1") + vertexai.init(project=project, location=location) + + task_lm = dspy.LM(f"vertex_ai/{task_model}", max_tokens=max_tokens) + dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM(f"vertex_ai/{reflection_model}", max_tokens=max_tokens) + return task_lm, reflection_lm diff --git a/src/llm_judge/providers/ollama_chat.py b/src/llm_judge/providers/ollama_chat.py new file mode 100644 index 0000000..edd717c --- /dev/null +++ b/src/llm_judge/providers/ollama_chat.py @@ -0,0 +1,28 @@ +from typing import Optional, Tuple + +import dspy + + +def init_ollama_chat( + task_model: str, + reflection_model: Optional[str] = None, + max_tokens: int = 1000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use Ollama Chat.""" + task_lm = dspy.LM( + f"ollama_chat/{task_model}", + api_base="http://localhost:11434", + api_key="", + max_tokens=max_tokens, + ) + dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM( + f"ollama_chat/{reflection_model}", + api_base="http://localhost:11434", + api_key="", + max_tokens=max_tokens, + ) + return task_lm, reflection_lm diff --git a/src/llm_judge/providers/openrouter.py b/src/llm_judge/providers/openrouter.py new file mode 100644 index 0000000..a8ba769 --- /dev/null +++ b/src/llm_judge/providers/openrouter.py @@ -0,0 +1,38 @@ +import os +from typing import Optional, Tuple + +import dspy + + +def init_openrouter( + task_model: str, + reflection_model: Optional[str] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + max_tokens: int = 4000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use OpenRouter.""" + base_url = base_url or os.getenv( + "OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1" + ) + api_key = api_key or os.getenv("OPENROUTER_API_KEY") + if not api_key: + raise ValueError("OPENROUTER_API_KEY is required for OpenRouter provider") + + task_lm = dspy.LM( + f"openrouter/{task_model}", + api_base=base_url, + api_key=api_key, + max_tokens=max_tokens, + ) + dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM( + f"openrouter/{reflection_model}", + api_base=base_url, + api_key=api_key, + max_tokens=max_tokens, + ) + return task_lm, reflection_lm diff --git a/src/llm_judge/signatures.py b/src/llm_judge/signatures.py new file mode 100644 index 0000000..7136703 --- /dev/null +++ b/src/llm_judge/signatures.py @@ -0,0 +1,58 @@ +import dspy +from modaic import PrecompiledAgent, PrecompiledConfig +from typing import Optional + + +class ClinicalImpactAssessment(dspy.Signature): + """Assess the clinical impact of transcription errors in medical conversations. + + Compare the ground truth conversation with the transcription conversation and determine + if errors would affect patient care. Focus on THREE distinct severity levels. + """ + + ground_truth_conversation = dspy.InputField() + transcription_conversation = dspy.InputField() + reasoning = dspy.OutputField( + desc="Brief clinical justification for the assessment." + ) + clinical_impact = dspy.OutputField( + desc="""Clinical impact class (return ONLY the number): + 0 = No impact: cosmetic differences only (punctuation, capitalization, filler words) + 1 = Minimal impact: some information missing/changed but NOT critical to diagnosis or treatment decisions + 2 = Significant impact: missing/incorrect information that COULD affect diagnosis, treatment, or patient safety + Return ONLY: 0, 1, or 2""" + ) + + +class ClinicalImpactJudgeConfig(PrecompiledConfig): + task_model: str = "gemini-2.5-pro" + reflection_model: Optional[str] = None # for GEPA runs + max_tokens: int = 8000 + temperature: float = 0.1 + test_size: Optional[int] = 50 + val_size: Optional[int] = 30 + seed: Optional[int] = 42 + auto: Optional[str] = "medium" + + +class ClinicalImpactJudge(PrecompiledAgent): + """LLM Judge for assessing clinical impact.""" + + config: ClinicalImpactJudgeConfig + + def __init__(self, config: ClinicalImpactJudgeConfig, **kwargs): + super().__init__(config, **kwargs) + self.assess = dspy.ChainOfThought(ClinicalImpactAssessment) + self.assess.set_lm( + dspy.LM( + config.task_model, + max_tokens=config.max_tokens, + temperature=config.temperature, + ) + ) + + def forward(self, ground_truth_conversation, transcription_conversation): + return self.assess( + ground_truth_conversation=ground_truth_conversation, + transcription_conversation=transcription_conversation, + )