From 0272ec2eb69a6ab6643259647d2c3acdec3ca024 Mon Sep 17 00:00:00 2001 From: Farouk Adeleke Date: Sun, 30 Nov 2025 06:19:03 -0500 Subject: [PATCH] Unoptimized Clinical Impact Judge --- README.md | 73 +++++++++++++- agent.json | 46 +++++++++ auto_classes.json | 4 + config.json | 10 ++ pyproject.toml | 12 +++ src/llm_judge/__init__.py | 10 ++ src/llm_judge/cli/__init__.py | 1 + src/llm_judge/cli/run_gepa.py | 131 +++++++++++++++++++++++++ src/llm_judge/data.py | 51 ++++++++++ src/llm_judge/eval.py | 62 ++++++++++++ src/llm_judge/metrics.py | 108 ++++++++++++++++++++ src/llm_judge/optimizers/__init__.py | 3 + src/llm_judge/optimizers/factory.py | 14 +++ src/llm_judge/optimizers/gepa.py | 8 ++ src/llm_judge/optimizers/mipro.py | 7 ++ src/llm_judge/providers/__init__.py | 3 + src/llm_judge/providers/bedrock.py | 28 ++++++ src/llm_judge/providers/factory.py | 25 +++++ src/llm_judge/providers/gemini.py | 24 +++++ src/llm_judge/providers/ollama_chat.py | 28 ++++++ src/llm_judge/providers/openrouter.py | 38 +++++++ src/llm_judge/signatures.py | 58 +++++++++++ 22 files changed, 743 insertions(+), 1 deletion(-) create mode 100644 agent.json create mode 100644 auto_classes.json create mode 100644 config.json create mode 100644 pyproject.toml create mode 100644 src/llm_judge/__init__.py create mode 100644 src/llm_judge/cli/__init__.py create mode 100644 src/llm_judge/cli/run_gepa.py create mode 100644 src/llm_judge/data.py create mode 100644 src/llm_judge/eval.py create mode 100644 src/llm_judge/metrics.py create mode 100644 src/llm_judge/optimizers/__init__.py create mode 100644 src/llm_judge/optimizers/factory.py create mode 100644 src/llm_judge/optimizers/gepa.py create mode 100644 src/llm_judge/optimizers/mipro.py create mode 100644 src/llm_judge/providers/__init__.py create mode 100644 src/llm_judge/providers/bedrock.py create mode 100644 src/llm_judge/providers/factory.py create mode 100644 src/llm_judge/providers/gemini.py create mode 100644 src/llm_judge/providers/ollama_chat.py create mode 100644 src/llm_judge/providers/openrouter.py create mode 100644 src/llm_judge/signatures.py diff --git a/README.md b/README.md index ef44548..a2189f9 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,73 @@ -# clinical-impact-judge +# WER is Unaware: Assessing How ASR Errors Distort Clinical Understanding in Patient-Facing Dialogue +This repository hosts the code, models, and datasets accompanying the paper. The work investigates how Automatic Speech Recognition (ASR) errors **distort clinical meaning in patient-facing dialogue** — and shows that traditional metrics like Word Error Rate (WER) fail to capture real clinical risk. The project includes scripts for aligning ground-truth utterances to ASR-generated utterances using an **LLM-based semantic aligner**, and optimizing an **LLM-as-a-Judge for clinical impact assessment** using GEPA through DSPy. + +## 📝 Abstract +![WER is Unaware Overview](overview.png) + +As Automatic Speech Recognition (ASR) is increasingly deployed in clinical dialogue, standard evaluations still rely heavily on Word Error Rate (WER). This paper challenges that standard, investigating whether WER or other common metrics correlate with the clinical impact of transcription errors. We establish a gold-standard benchmark by having expert clinicians compare ground-truth utterances to their ASR-generated counterparts, labeling the clinical impact of any discrepancies found in two distinct doctor-patient dialogue datasets. 
Our analysis reveals that WER and a comprehensive suite of existing metrics correlate poorly with the clinician-assigned risk labels (No, Minimal, or Significant Impact). To bridge this evaluation gap, we introduce an LLM-as-a-Judge, programmatically optimized using GEPA to replicate expert clinical assessment. The optimized judge (Gemini-2.5-Pro) achieves human-comparable performance, obtaining 90% accuracy and a strong Cohen's Îș of 0.816. This work provides a validated, automated framework for moving ASR evaluation beyond simple textual fidelity to a necessary, scalable assessment of safety in clinical dialogue.
+
+## 🔍 Overview
+
+This repository provides:
+- Clinician-annotated clinical-impact dataset: `llm_judge/dataset/primock_data_final_outcomes.csv`
+- Semantic LLM-based aligner: `alignment/aligner/` (see `alignment/README.md` for usage)
+- LLM-as-a-Judge optimized with GEPA/MIPRO: `llm_judge/` (artifacts in `llm_judge/results/`)
+- Evaluations of ASR metrics (code under `alignment/scripts/` and `alignment/results/`)
+
+## đŸ› ïž Environment Setup
+
+- Install Python 3.11+ and `uv` (recommended): https://github.com/astral-sh/uv
+- Install dependencies: `uv sync`
+- Environment variables:
+  - OpenRouter (default for LLM calls): `OPENROUTER_API_KEY` (required), `OPENROUTER_MODEL` (optional)
+  - Gemini (optional): `GCP_PROJECT_ID`, `GCP_LOCATION`
+  - Bedrock (optional): `AWS_REGION`
+
+- Example: run aligner evaluation
+  ```bash
+  uv run python alignment/scripts/run_evaluation.py --case-id sample --asr-system demo
+  ```
+
+- Example: run judge (GEPA)
+  ```bash
+  uv run python -m llm_judge.cli.run_gepa \
+    --data-path llm_judge/dataset/primock_data_final_outcomes.csv \
+    --provider openrouter \
+    --task-model meta-llama/llama-3.3-70b-instruct \
+    --reflection-model anthropic/claude-4-sonnet \
+    --output llm_judge/results/clinical_judge_gepa.json
+  ```
+
+## 📁 Folder Structure
+- `alignment/` — semantic alignment toolkit (aligner code, scripts, sample data, sample results).
+- `llm_judge/` — clinical impact judge (signatures, metrics, providers, optimizers, CLI, bundled dataset, saved judges).
+
+### Important Files
+- `alignment/data/` — example ASR transcripts and ground-truth alignments.
+- `alignment/results/` — sample alignment evaluations.
+- `llm_judge/dataset/` — clinical-impact dataset.
+- `llm_judge/results/` — optimized judges (GEPA, MIPROv2).
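+
+### Programmatic Usage
+
+The judge can also be driven directly from Python. The snippet below is a minimal sketch that evaluates the unoptimized judge on the held-out test split; it assumes `OPENROUTER_API_KEY` is set, uses the `src.llm_judge` import paths from `run_gepa.py`, and the model identifier is illustrative only.
+
+```python
+from dotenv import load_dotenv
+
+from src.llm_judge import data as data_utils
+from src.llm_judge import eval as eval_utils
+from src.llm_judge import signatures
+
+load_dotenv()  # provides OPENROUTER_API_KEY for the judge's LM
+
+# Load the clinician-labelled dataset and rebuild the train/val/test splits.
+df = data_utils.load_dataset("llm_judge/dataset/primock_data_final_outcomes.csv")
+_, _, testset = data_utils.build_splits(df, test_size=50, val_size=30, random_state=42)
+
+# The judge builds its own dspy.LM from task_model, so pass a provider-prefixed id.
+config = signatures.ClinicalImpactJudgeConfig(
+    task_model="openrouter/meta-llama/llama-3.3-70b-instruct"
+)
+judge = signatures.ClinicalImpactJudge(config)
+
+# Prints accuracy, Cohen's kappa, a classification report, and a confusion matrix.
+eval_utils.evaluate_judge(judge, testset, name="Unoptimized baseline")
+```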
+ +## 📩 Coming Soon + +- Additional dataset metadata and documentation +- Evaluations of 20+ ASR metrics, showing their poor correlation with clinical safety + +## 📄 Paper + +Preprint available on arXiv: https://arxiv.org/abs/2511.16544 + +## 📚 Citation + +```bibtex +@misc{ellis2025werunawareassessingasr, + title={WER is Unaware: Assessing How ASR Errors Distort Clinical Understanding in Patient Facing Dialogue}, + author={Zachary Ellis and Jared Joselowitz and Yash Deo and Yajie He and Anna Kalygina and Aisling Higham and Mana Rahimzadeh and Yan Jia and Ibrahim Habli and Ernest Lim}, + year={2025}, + eprint={2511.16544}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2511.16544}, +} +``` diff --git a/agent.json b/agent.json new file mode 100644 index 0000000..cb0c01e --- /dev/null +++ b/agent.json @@ -0,0 +1,46 @@ +{ + "assess.predict": { + "traces": [], + "train": [], + "demos": [], + "signature": { + "instructions": "Assess the clinical impact of transcription errors in medical conversations.\n\nCompare the ground truth conversation with the transcription conversation and determine\nif errors would affect patient care. Focus on THREE distinct severity levels.", + "fields": [ + { + "prefix": "Ground Truth Conversation:", + "description": "${ground_truth_conversation}" + }, + { + "prefix": "Transcription Conversation:", + "description": "${transcription_conversation}" + }, + { + "prefix": "Reasoning:", + "description": "Brief clinical justification for the assessment." + }, + { + "prefix": "Clinical Impact:", + "description": "Clinical impact class (return ONLY the number):\n 0 = No impact: cosmetic differences only (punctuation, capitalization, filler words)\n 1 = Minimal impact: some information missing/changed but NOT critical to diagnosis or treatment decisions \n 2 = Significant impact: missing/incorrect information that COULD affect diagnosis, treatment, or patient safety\n Return ONLY: 0, 1, or 2" + } + ] + }, + "lm": { + "model": "openroutergoogle/gemini-2.5-pro", + "model_type": "chat", + "cache": true, + "num_retries": 3, + "finetuning_model": null, + "launch_kwargs": {}, + "train_kwargs": {}, + "temperature": 0.1, + "max_tokens": 8000 + } + }, + "metadata": { + "dependency_versions": { + "python": "3.13", + "dspy": "3.0.4", + "cloudpickle": "3.1" + } + } +} \ No newline at end of file diff --git a/auto_classes.json b/auto_classes.json new file mode 100644 index 0000000..9764aba --- /dev/null +++ b/auto_classes.json @@ -0,0 +1,4 @@ +{ + "AutoConfig": "src.llm_judge.signatures.ClinicalImpactJudgeConfig", + "AutoAgent": "src.llm_judge.signatures.ClinicalImpactJudge" +} \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..62206ab --- /dev/null +++ b/config.json @@ -0,0 +1,10 @@ +{ + "task_model": "openroutergoogle/gemini-2.5-pro", + "reflection_model": "openrouteranthropic/claude-4-sonnet", + "max_tokens": 8000, + "temperature": 0.1, + "test_size": 50, + "val_size": 30, + "seed": 42, + "auto": "medium" +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3ff6d3c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "clinical-impact-judge" +version = "0.1.0" +description = "LLM transcript alignment and evaluation toolkit" +readme = "README.md" +requires-python = ">=3.11" +dependencies = ["openai>=1.35.0", "num2words>=0.5.13", "python-dotenv>=1.2.1", "dspy>=3.0.4", "modaic>=0.4.1", "scikit-learn>=1.7.2", "vertexai>=1.71.1"] + 
+[project.optional-dependencies] +plot = ["matplotlib>=3.8"] +dev = ["pytest>=8.0"] + diff --git a/src/llm_judge/__init__.py b/src/llm_judge/__init__.py new file mode 100644 index 0000000..44398e5 --- /dev/null +++ b/src/llm_judge/__init__.py @@ -0,0 +1,10 @@ +"""Lightweight helpers for running DSPy judges with GEPA or MIPRO.""" + +__all__ = [ + "signatures", + "metrics", + "data", + "models", + "optimizers", + "eval", +] diff --git a/src/llm_judge/cli/__init__.py b/src/llm_judge/cli/__init__.py new file mode 100644 index 0000000..742a1a7 --- /dev/null +++ b/src/llm_judge/cli/__init__.py @@ -0,0 +1 @@ +# CLI entrypoints live here for `python -m llm_judge.cli.run_*`. diff --git a/src/llm_judge/cli/run_gepa.py b/src/llm_judge/cli/run_gepa.py new file mode 100644 index 0000000..bd958c8 --- /dev/null +++ b/src/llm_judge/cli/run_gepa.py @@ -0,0 +1,131 @@ +import argparse + +from dotenv import load_dotenv + +from src.llm_judge import data as data_utils +from src.llm_judge import eval as eval_utils +from src.llm_judge import metrics, signatures +from src.llm_judge.optimizers import get_optimizer +from src.llm_judge.providers import setup_models + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run GEPA optimization for clinical impact judge." + ) + parser.add_argument( + "--data-path", + type=str, + default="llm_judge/dataset/primock_data_final_outcomes.csv", + help="CSV file path.", + ) + parser.add_argument( + "--provider", + type=str, + default="openrouter", + choices=["gemini", "bedrock", "openrouter"], + ) + parser.add_argument( + "--task-model", type=str, default="meta-llama/llama-3.3-70b-instruct" + ) + parser.add_argument( + "--reflection-model", type=str, default="anthropic/claude-4-sonnet" + ) + parser.add_argument( + "--no-separate-reflection", + action="store_true", + help="Use task model for reflection too.", + ) + parser.add_argument("--output", type=str, default="clinical_judge_gepa.json") + parser.add_argument("--test-size", type=int, default=50) + parser.add_argument("--val-size", type=int, default=30) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--auto", type=str, default="medium", help="GEPA auto level: light|medium|heavy" + ) + return parser.parse_args() + + +def main(): + load_dotenv() + args = parse_args() + + if not args.data_path: + raise SystemExit("Please provide --data-path or set DATA_PATH.") + + separate_reflection = not args.no_separate_reflection + print("=" * 80) + print("DSPy Clinical Impact Judge - GEPA") + print("=" * 80) + + # Data + df = data_utils.load_dataset(args.data_path) + trainset, valset, testset = data_utils.build_splits( + df, test_size=args.test_size, val_size=args.val_size, random_state=args.seed + ) + print(f"Train: {len(trainset)} | Val: {len(valset)} | Test: {len(testset)}") + + # Models + reflection_model = args.reflection_model if separate_reflection else None + task_lm, reflection_lm = setup_models( + args.provider, task_model=args.task_model, reflection_model=reflection_model + ) + print(f"Connected to provider={args.provider} task_model={args.task_model}") + if reflection_lm: + print(f"Using separate reflection model: {args.reflection_model}") + + model_base = ( + args.provider + if args.provider == "openrouter" + or args.provider == "bedrock" + or args.provider == "ollama_chat" + else "vertex_ai/" + ) + config = signatures.ClinicalImpactJudgeConfig( + task_model=model_base + args.task_model, + reflection_model=model_base + args.reflection_model, + test_size=args.test_size, + 
val_size=args.val_size, + seed=args.seed, + auto=args.auto, + max_tokens=8000, + ) + judge = signatures.ClinicalImpactJudge(config) + judge.push_to_hub( + "jaredjoss123/clinical-impact-judge", + with_code=True, + commit_message="Unoptimized Clinical Impact Judge", + ) + + # Optimizer + optimizer = get_optimizer( + "gepa", + metric=metrics.gepa_feedback_metric, + reflection_lm=reflection_lm, + auto=args.auto, + reflection_minibatch_size=3, + candidate_selection_strategy="pareto", + skip_perfect_score=True, + track_stats=True, + seed=args.seed, + ) + + optimized_judge = optimizer.compile( + judge, + trainset=trainset, + valset=valset, + ) + optimized_judge.save(args.output) + optimized_judge.push_to_hub( + "jaredjoss123/clinical-impact-judge-gepa", + with_code=True, + commit_message="GEPA Optimized on Clinical Impact Judge", + ) + print(f"Saved optimized judge to {args.output}") + + # Evaluate on test + eval_utils.evaluate_judge(optimized_judge, testset, name="GEPA Optimized") + + +if __name__ == "__main__": + main() diff --git a/src/llm_judge/data.py b/src/llm_judge/data.py new file mode 100644 index 0000000..0db3cb3 --- /dev/null +++ b/src/llm_judge/data.py @@ -0,0 +1,51 @@ +from typing import List, Tuple + +import dspy +import pandas as pd +from sklearn.model_selection import train_test_split + + +def load_dataset(csv_path: str) -> pd.DataFrame: + """Load dataset and drop rows without final_outcome.""" + df = pd.read_csv(csv_path) + return df.dropna(subset=["final_outcome"]) + + +def split_dataset( + df: pd.DataFrame, + test_size: int = 50, + val_size: int = 30, + random_state: int = 42, +) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Split into train/val/test with fixed sizes and stratification.""" + y = df["final_outcome"] + train_val_df, test_df = train_test_split( + df, test_size=test_size, stratify=y, random_state=random_state + ) + train_df, val_df = train_test_split( + train_val_df, + test_size=val_size, + stratify=train_val_df["final_outcome"], + random_state=random_state, + ) + return train_df, val_df, test_df + + +def create_dspy_example(row) -> dspy.Example: + """Convert dataframe row to DSPy Example.""" + return dspy.Example( + ground_truth_conversation=str(row["fer_gt_context"]), + transcription_conversation=str(row["fer_hyp_context"]), + clinical_impact=str(int(row["final_outcome"])), + ).with_inputs("ground_truth_conversation", "transcription_conversation") + + +def build_splits( + df: pd.DataFrame, test_size: int = 50, val_size: int = 30, random_state: int = 42 +) -> Tuple[List[dspy.Example], List[dspy.Example], List[dspy.Example]]: + """Create DSPy train/val/test example lists from dataframe splits.""" + train_df, val_df, test_df = split_dataset(df, test_size, val_size, random_state) + trainset = [create_dspy_example(row) for _, row in train_df.iterrows()] + valset = [create_dspy_example(row) for _, row in val_df.iterrows()] + testset = [create_dspy_example(row) for _, row in test_df.iterrows()] + return trainset, valset, testset diff --git a/src/llm_judge/eval.py b/src/llm_judge/eval.py new file mode 100644 index 0000000..538f42d --- /dev/null +++ b/src/llm_judge/eval.py @@ -0,0 +1,62 @@ +import pandas as pd +from sklearn.metrics import classification_report, cohen_kappa_score, confusion_matrix + +from .metrics import parse_label + + +def evaluate_judge(judge, testset, name="Judge"): + """Evaluate a judge on a testset and print metrics.""" + print("\n" + "=" * 80) + print(f"EVALUATING: {name}") + print("=" * 80) + + results = [] + for idx, example in 
enumerate(testset): + try: + prediction = judge( + ground_truth_conversation=example.ground_truth_conversation, + transcription_conversation=example.transcription_conversation, + ) + pred_label = parse_label(prediction.clinical_impact) + true_label = int(example.clinical_impact) + if pred_label is not None: + results.append({"true_label": true_label, "pred_label": pred_label}) + except Exception as exc: # pragma: no cover - runtime guardrail + print(f"Error on example {idx}: {exc}") + continue + + if not results: + return None, 0, 0 + + results_df = pd.DataFrame(results) + true_labels = results_df["true_label"].values + pred_labels = results_df["pred_label"].values + + accuracy = (true_labels == pred_labels).mean() * 100 + kappa = cohen_kappa_score(true_labels, pred_labels) + + print(f"\nAccuracy: {accuracy:.2f}%") + print(f"Cohen's Kappa: {kappa:.3f}") + print("\nClassification Report:") + print( + classification_report( + true_labels, + pred_labels, + target_names=["0 (No impact)", "1 (Minimal)", "2 (Significant)"], + zero_division=0, + ) + ) + print("\nConfusion Matrix:") + cm = confusion_matrix(true_labels, pred_labels) + print(" Predicted") + print(" 0 1 2") + for i, row_label in enumerate(["Actual 0", "Actual 1", "Actual 2"]): + print(f"{row_label:10s} {cm[i][0]:4d} {cm[i][1]:4d} {cm[i][2]:4d}") + + for class_label in [0, 1, 2]: + mask = true_labels == class_label + if mask.sum() > 0: + recall = (true_labels[mask] == pred_labels[mask]).mean() * 100 + print(f"Class {class_label} recall: {recall:.1f}%") + + return results_df, accuracy, kappa diff --git a/src/llm_judge/metrics.py b/src/llm_judge/metrics.py new file mode 100644 index 0000000..d25fdb3 --- /dev/null +++ b/src/llm_judge/metrics.py @@ -0,0 +1,108 @@ +import json +import re +from typing import Optional + +import dspy + +COST_MATRIX = [ + [1.2, 0.3, -1.0], + [0.3, 1.5, 0.5], + [-1.2, 0.4, 1.5], +] + + +def parse_label(label_str: str) -> Optional[int]: + try: + label_str = str(label_str).strip() + if label_str in {"0", "1", "2"}: + return int(label_str) + + json_match = re.search(r"\{.*\}", label_str, re.DOTALL) + if json_match: + obj = json.loads(json_match.group(0)) + val = obj.get("clinical_impact") + if val in [0, 1, 2] or str(val) in "012": + return int(val) + + num_match = re.search(r"\b([0-2])\b", label_str) + if num_match: + return int(num_match.group(1)) + except Exception: + return None + return None + + +def gepa_feedback_metric( + example, prediction, trace=None, pred_name=None, pred_trace=None +): + true_label = int(example.clinical_impact) + pred_label = parse_label(prediction.clinical_impact) + + if pred_label is None: + feedback = ( + f"PARSING ERROR: The model failed to output a valid class (0, 1, or 2). " + f"Raw output: '{prediction.clinical_impact}'. " + f"The model MUST return ONLY the number 0, 1, or 2 as specified in the output field description. " + f"Consider emphasizing in the instructions: output format must be strictly a single digit." + ) + return dspy.Prediction(score=-2.0, feedback=feedback) + + score = COST_MATRIX[true_label][pred_label] + + # Generate detailed feedback based on the prediction outcome + if pred_label == true_label: + class_names = {0: "No impact", 1: "Minimal impact", 2: "Significant impact"} + feedback = ( + f"CORRECT: Correctly identified as Class {true_label} ({class_names[true_label]}). " + f"The model's reasoning was appropriate for this classification. " + f"Continue using similar reasoning patterns for this type of case." 
+ ) + else: + if true_label == 0 and pred_label > 0: + feedback = ( + f"OVER-CLASSIFICATION: Predicted Class {pred_label} but should be Class 0 (No impact). " + f"The transcription differences are cosmetic (punctuation, capitalization, filler words) " + f"and do NOT affect clinical meaning. The model should be MORE LENIENT with minor differences " + f"and focus ONLY on content that affects diagnosis or treatment decisions." + ) + elif true_label == 1 and pred_label == 0: + feedback = ( + f"UNDER-CLASSIFICATION: Predicted Class 0 but should be Class 1 (Minimal impact). " + f"While not critical to diagnosis/treatment, some clinically relevant information was " + f"missing or changed. The model should be MORE SENSITIVE to information changes, " + f"even if they don't directly affect critical decisions." + ) + elif true_label == 1 and pred_label == 2: + feedback = ( + f"OVER-CLASSIFICATION: Predicted Class 2 but should be Class 1 (Minimal impact). " + f"The information changes are not critical enough to affect diagnosis or patient safety. " + f"Reserve Class 2 ONLY for errors that COULD directly affect diagnosis, treatment, or safety. " + f"The model should distinguish between 'some information missing' vs 'critical information missing'." + ) + elif true_label == 2 and pred_label < 2: + feedback = ( + f"CRITICAL MISS: Predicted Class {pred_label} but should be Class 2 (Significant impact). " + f"This is a HIGH-PRIORITY error. The transcription contained missing/incorrect information " + f"that COULD affect diagnosis, treatment, or patient safety. The model MUST be MORE SENSITIVE " + f"to clinically critical information like symptoms, medications, measurements, or diagnoses. " + f"Look for: changed medical terms, missing symptoms, altered measurements, or omitted diagnoses." + ) + else: + feedback = ( + f"MAJOR ERROR: Predicted Class {pred_label} but should be Class {true_label}. " + f"This is a large classification error spanning 2 severity levels. " + f"The model needs to fundamentally reassess its criteria for clinical impact. " + f"Review the distinction between cosmetic changes, information changes, and critical errors." 
+ ) + + feedback += f" [True: {true_label}, Predicted: {pred_label}]" + + return dspy.Prediction(score=score, feedback=feedback) + + +def simple_metric(example, prediction, trace=None): + true_label = int(example.clinical_impact) + pred_label = parse_label(prediction.clinical_impact) + if pred_label is None: + return -2.0 + return COST_MATRIX[true_label][pred_label] diff --git a/src/llm_judge/optimizers/__init__.py b/src/llm_judge/optimizers/__init__.py new file mode 100644 index 0000000..000c02e --- /dev/null +++ b/src/llm_judge/optimizers/__init__.py @@ -0,0 +1,3 @@ +from .factory import get_optimizer + +__all__ = ["get_optimizer"] diff --git a/src/llm_judge/optimizers/factory.py b/src/llm_judge/optimizers/factory.py new file mode 100644 index 0000000..fad41d7 --- /dev/null +++ b/src/llm_judge/optimizers/factory.py @@ -0,0 +1,14 @@ +from typing import Any, Optional + +from .gepa import build_gepa +from .mipro import build_mipro + + +def get_optimizer(name: str, metric, reflection_lm: Optional[Any] = None, **kwargs): + """Return a configured optimizer by name.""" + name = name.lower() + if name == "gepa": + return build_gepa(metric=metric, reflection_lm=reflection_lm, **kwargs) + if name in {"mipro", "miprov2"}: + return build_mipro(metric=metric, **kwargs) + raise ValueError(f"Unsupported optimizer: {name}") diff --git a/src/llm_judge/optimizers/gepa.py b/src/llm_judge/optimizers/gepa.py new file mode 100644 index 0000000..c83b22e --- /dev/null +++ b/src/llm_judge/optimizers/gepa.py @@ -0,0 +1,8 @@ +from dspy.teleprompt import GEPA + + +def build_gepa(metric, reflection_lm=None, **kwargs): + """Construct a GEPA optimizer.""" + if reflection_lm is not None: + kwargs["reflection_lm"] = reflection_lm + return GEPA(metric=metric, **kwargs) diff --git a/src/llm_judge/optimizers/mipro.py b/src/llm_judge/optimizers/mipro.py new file mode 100644 index 0000000..1ff4e7e --- /dev/null +++ b/src/llm_judge/optimizers/mipro.py @@ -0,0 +1,7 @@ +from dspy.teleprompt import MIPROv2 + + +def build_mipro(metric, **kwargs): + """Construct a MIPROv2 optimizer.""" + allowed = {k: v for k, v in kwargs.items() if k in {"auto", "seed"}} + return MIPROv2(metric=metric, **allowed) diff --git a/src/llm_judge/providers/__init__.py b/src/llm_judge/providers/__init__.py new file mode 100644 index 0000000..5369c0a --- /dev/null +++ b/src/llm_judge/providers/__init__.py @@ -0,0 +1,3 @@ +from .factory import setup_models + +__all__ = ["setup_models"] diff --git a/src/llm_judge/providers/bedrock.py b/src/llm_judge/providers/bedrock.py new file mode 100644 index 0000000..1ec7158 --- /dev/null +++ b/src/llm_judge/providers/bedrock.py @@ -0,0 +1,28 @@ +from typing import Optional, Tuple +import os +import dspy +from dotenv import load_dotenv + +load_dotenv() + + +def init_bedrock( + task_model: str, + reflection_model: Optional[str] = None, + region: str = "us-east-1", + max_tokens: int = 1000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use AWS Bedrock.""" + task_lm = dspy.LM( + f"bedrock/{task_model}", + region_name=(os.getenv("AWS_REGION", region)), + max_tokens=max_tokens, + ) + dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM( + f"bedrock/{reflection_model}", region_name=region, max_tokens=max_tokens + ) + return task_lm, reflection_lm diff --git a/src/llm_judge/providers/factory.py b/src/llm_judge/providers/factory.py new file mode 100644 index 0000000..1a5529a --- /dev/null +++ b/src/llm_judge/providers/factory.py @@ -0,0 +1,25 @@ +from 
typing import Optional + +from .bedrock import init_bedrock +from .gemini import init_gemini +from .openrouter import init_openrouter +from .ollama_chat import init_ollama_chat + + +def setup_models( + provider: str, + task_model: str, + reflection_model: Optional[str] = None, + **kwargs, +): + """Initialize task/reflection LMs and configure DSPy.""" + provider = provider.lower() + if provider == "gemini": + return init_gemini(task_model, reflection_model=reflection_model, **kwargs) + if provider == "bedrock": + return init_bedrock(task_model, reflection_model=reflection_model, **kwargs) + if provider == "openrouter": + return init_openrouter(task_model, reflection_model=reflection_model, **kwargs) + if provider == "ollam_chat": + return init_ollama_chat(task_model, reflection_model=reflection_model, **kwargs) + raise ValueError(f"Unsupported provider: {provider}") diff --git a/src/llm_judge/providers/gemini.py b/src/llm_judge/providers/gemini.py new file mode 100644 index 0000000..4dedd26 --- /dev/null +++ b/src/llm_judge/providers/gemini.py @@ -0,0 +1,24 @@ +import os +from typing import Optional, Tuple + +import dspy +import vertexai + + +def init_gemini( + task_model: str, + reflection_model: Optional[str] = None, + max_tokens: int = 8000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use Gemini via Vertex AI.""" + project = os.getenv("GCP_PROJECT_ID", "your-project-id") + location = os.getenv("GCP_LOCATION", "us-central1") + vertexai.init(project=project, location=location) + + task_lm = dspy.LM(f"vertex_ai/{task_model}", max_tokens=max_tokens) + dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM(f"vertex_ai/{reflection_model}", max_tokens=max_tokens) + return task_lm, reflection_lm diff --git a/src/llm_judge/providers/ollama_chat.py b/src/llm_judge/providers/ollama_chat.py new file mode 100644 index 0000000..edd717c --- /dev/null +++ b/src/llm_judge/providers/ollama_chat.py @@ -0,0 +1,28 @@ +from typing import Optional, Tuple + +import dspy + + +def init_ollama_chat( + task_model: str, + reflection_model: Optional[str] = None, + max_tokens: int = 1000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use Ollama Chat.""" + task_lm = dspy.LM( + f"ollama_chat/{task_model}", + api_base="http://localhost:11434", + api_key="", + max_tokens=max_tokens, + ) + dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM( + f"ollama_chat/{reflection_model}", + api_base="http://localhost:11434", + api_key="", + max_tokens=max_tokens, + ) + return task_lm, reflection_lm diff --git a/src/llm_judge/providers/openrouter.py b/src/llm_judge/providers/openrouter.py new file mode 100644 index 0000000..a8ba769 --- /dev/null +++ b/src/llm_judge/providers/openrouter.py @@ -0,0 +1,38 @@ +import os +from typing import Optional, Tuple + +import dspy + + +def init_openrouter( + task_model: str, + reflection_model: Optional[str] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + max_tokens: int = 4000, +) -> Tuple[dspy.LM, Optional[dspy.LM]]: + """Configure DSPy to use OpenRouter.""" + base_url = base_url or os.getenv( + "OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1" + ) + api_key = api_key or os.getenv("OPENROUTER_API_KEY") + if not api_key: + raise ValueError("OPENROUTER_API_KEY is required for OpenRouter provider") + + task_lm = dspy.LM( + f"openrouter/{task_model}", + api_base=base_url, + api_key=api_key, + max_tokens=max_tokens, + ) + 
dspy.settings.configure(lm=task_lm) + + reflection_lm = None + if reflection_model: + reflection_lm = dspy.LM( + f"openrouter/{reflection_model}", + api_base=base_url, + api_key=api_key, + max_tokens=max_tokens, + ) + return task_lm, reflection_lm diff --git a/src/llm_judge/signatures.py b/src/llm_judge/signatures.py new file mode 100644 index 0000000..7136703 --- /dev/null +++ b/src/llm_judge/signatures.py @@ -0,0 +1,58 @@ +import dspy +from modaic import PrecompiledAgent, PrecompiledConfig +from typing import Optional + + +class ClinicalImpactAssessment(dspy.Signature): + """Assess the clinical impact of transcription errors in medical conversations. + + Compare the ground truth conversation with the transcription conversation and determine + if errors would affect patient care. Focus on THREE distinct severity levels. + """ + + ground_truth_conversation = dspy.InputField() + transcription_conversation = dspy.InputField() + reasoning = dspy.OutputField( + desc="Brief clinical justification for the assessment." + ) + clinical_impact = dspy.OutputField( + desc="""Clinical impact class (return ONLY the number): + 0 = No impact: cosmetic differences only (punctuation, capitalization, filler words) + 1 = Minimal impact: some information missing/changed but NOT critical to diagnosis or treatment decisions + 2 = Significant impact: missing/incorrect information that COULD affect diagnosis, treatment, or patient safety + Return ONLY: 0, 1, or 2""" + ) + + +class ClinicalImpactJudgeConfig(PrecompiledConfig): + task_model: str = "gemini-2.5-pro" + reflection_model: Optional[str] = None # for GEPA runs + max_tokens: int = 8000 + temperature: float = 0.1 + test_size: Optional[int] = 50 + val_size: Optional[int] = 30 + seed: Optional[int] = 42 + auto: Optional[str] = "medium" + + +class ClinicalImpactJudge(PrecompiledAgent): + """LLM Judge for assessing clinical impact.""" + + config: ClinicalImpactJudgeConfig + + def __init__(self, config: ClinicalImpactJudgeConfig, **kwargs): + super().__init__(config, **kwargs) + self.assess = dspy.ChainOfThought(ClinicalImpactAssessment) + self.assess.set_lm( + dspy.LM( + config.task_model, + max_tokens=config.max_tokens, + temperature=config.temperature, + ) + ) + + def forward(self, ground_truth_conversation, transcription_conversation): + return self.assess( + ground_truth_conversation=ground_truth_conversation, + transcription_conversation=transcription_conversation, + )
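+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only, not part of the optimization pipeline):
+    # judge one hypothetical utterance pair and score it with the GEPA feedback metric.
+    # Assumes task_model is a provider-prefixed id that dspy.LM/LiteLLM can resolve and
+    # that the matching API key (e.g. OPENROUTER_API_KEY) is set in the environment.
+    from src.llm_judge.metrics import gepa_feedback_metric, parse_label
+
+    config = ClinicalImpactJudgeConfig(task_model="openrouter/google/gemini-2.5-pro")
+    judge = ClinicalImpactJudge(config)
+
+    prediction = judge(
+        ground_truth_conversation="Doctor: Are you taking 5 mg of amlodipine daily?",
+        transcription_conversation="Doctor: Are you taking 50 mg of amlodipine daily?",
+    )
+    print("Reasoning:", prediction.reasoning)
+    print("Parsed label:", parse_label(prediction.clinical_impact))  # 0, 1, or 2
+
+    # Score against a clinician gold label (here assumed to be 2 = significant impact).
+    gold = dspy.Example(clinical_impact="2")
+    result = gepa_feedback_metric(gold, prediction)
+    print("Score:", result.score)
+    print("Feedback:", result.feedback)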