GEPA Optimized on Clinical Impact Judge

2025-11-30 07:08:00 -05:00
parent 50f18698ce
commit 347ef3b3c0
22 changed files with 743 additions and 1 deletion

README.md

@@ -1,2 +1,73 @@
# clinical-impact-judge-gepa
# WER is Unaware: Assessing How ASR Errors Distort Clinical Understanding in Patient-Facing Dialogue
This repository hosts the code, models, and datasets accompanying the paper. The work investigates how Automatic Speech Recognition (ASR) errors **distort clinical meaning in patient-facing dialogue**, and shows that traditional metrics such as Word Error Rate (WER) fail to capture real clinical risk. The project includes scripts for aligning ground-truth utterances to ASR-generated utterances with an **LLM-based semantic aligner**, and for optimizing an **LLM-as-a-Judge for clinical impact assessment** with GEPA via DSPy.
## 📝 Abstract
![WER is Unaware Overview](overview.png)
As Automatic Speech Recognition (ASR) is increasingly deployed in clinical dialogue, standard evaluations still rely heavily on Word Error Rate (WER). This paper challenges that standard, investigating whether WER or other common metrics correlate with the clinical impact of transcription errors. We establish a gold-standard benchmark by having expert clinicians compare ground-truth utterances to their ASR-generated counterparts, labeling the clinical impact of any discrepancies found in two distinct doctor-patient dialogue datasets. Our analysis reveals that WER and a comprehensive suite of existing metrics correlate poorly with the clinician-assigned risk labels (No, Minimal, or Significant Impact). To bridge this evaluation gap, we introduce an LLM-as-a-Judge, programmatically optimized using GEPA to replicate expert clinical assessment. The optimized judge (Gemini-2.5-Pro) achieves human-comparable performance, obtaining 90% accuracy and a strong Cohen's κ of 0.816. This work provides a validated, automated framework for moving ASR evaluation beyond simple textual fidelity to a necessary, scalable assessment of safety in clinical dialogue.
## 🔍 Overview
We introduce (all available in this repository):
- Clinician-annotated clinical-impact dataset: `llm_judge/dataset/primock_data_final_outcomes.csv`
- Semantic LLM-based aligner: `alignment/aligner/` (see `alignment/README.md` for usage)
- LLM-as-a-Judge optimized with GEPA/MIPRO: `llm_judge/` (artifacts in `llm_judge/results/`)
- Evaluations of ASR metrics (code under `alignment/scripts/` and `alignment/results/`)
## 🛠️ Environment Setup
- Install Python 3.11+ and `uv` (recommended): https://github.com/astral-sh/uv
- Install dependencies: `uv sync`
- Environment:
- OpenRouter (default for LLM calls): `OPENROUTER_API_KEY` (required), `OPENROUTER_MODEL` optional
- Gemini (optional): `GCP_PROJECT_ID`, `GCP_LOCATION`
- Bedrock (optional): `AWS_REGION`
- Example: run aligner evaluation
```bash
uv run python alignment/scripts/run_evaluation.py --case-id sample --asr-system demo
```
- Example: run judge (GEPA)
```bash
uv run python -m llm_judge.cli.run_gepa \
--data-path llm_judge/dataset/primock_data_final_outcomes.csv \
--provider openrouter \
--task-model meta-llama/llama-3.3-70b-instruct \
--reflection-model anthropic/claude-4-sonnet \
--output llm_judge/results/clinical_judge_gepa.json
```
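- Example: load a saved judge for inference (a minimal sketch rather than a packaged entrypoint; it assumes `OPENROUTER_API_KEY` is set, that the GEPA run above wrote `llm_judge/results/clinical_judge_gepa.json`, and that the judge restores its optimized prompt via the standard DSPy `load()`; the utterances are toy examples)
```python
from dotenv import load_dotenv

from src.llm_judge import signatures

load_dotenv()  # picks up OPENROUTER_API_KEY from a local .env file

config = signatures.ClinicalImpactJudgeConfig(task_model="openrouter/google/gemini-2.5-pro")
judge = signatures.ClinicalImpactJudge(config)
judge.load("llm_judge/results/clinical_judge_gepa.json")  # path assumed from the GEPA example above

prediction = judge(
    ground_truth_conversation="Doctor: So you take 10 milligrams of lisinopril every morning?",
    transcription_conversation="Doctor: So you take 100 milligrams of lisinopril every morning?",
)
print(prediction.reasoning)
print(prediction.clinical_impact)  # expected "2": a tenfold dosage error is safety-critical
```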
## 📁 Folder Structure
- `alignment/` — semantic alignment toolkit (aligner code, scripts, sample data, sample results).
- `llm_judge/` — clinical impact judge (signatures, metrics, providers, optimizers, CLI, bundled dataset, saved judges).
### Important Files
- `alignment/data/` — example ASR transcripts and ground-truth alignments.
- `alignment/results/` — sample alignment evaluations.
- `llm_judge/dataset/` — clinical-impact dataset.
- `llm_judge/results/` — optimized judges (GEPA, MIPROv2).
## 📦 Coming Soon
- Additional dataset metadata and documentation
- Evaluations of 20+ ASR metrics, showing their poor correlation with clinical safety
## 📄 Paper
Preprint available on arXiv: https://arxiv.org/abs/2511.16544
## 📚 Citation
```bibtex
@misc{ellis2025werunawareassessingasr,
title={WER is Unaware: Assessing How ASR Errors Distort Clinical Understanding in Patient Facing Dialogue},
author={Zachary Ellis and Jared Joselowitz and Yash Deo and Yajie He and Anna Kalygina and Aisling Higham and Mana Rahimzadeh and Yan Jia and Ibrahim Habli and Ernest Lim},
year={2025},
eprint={2511.16544},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2511.16544},
}
```

agent.json Normal file

@@ -0,0 +1,46 @@
{
"assess.predict": {
"traces": [],
"train": [],
"demos": [],
"signature": {
"instructions": "Assess the clinical impact of transcription errors in medical conversations by comparing ground truth and transcribed conversations. Classify errors into THREE severity levels based on their potential to affect patient care:\n\n**Classification Levels:**\n- **Class 0 (No Impact)**: Cosmetic differences that don't affect clinical meaning\n- **Class 1 (Moderate Impact)**: Errors that could cause confusion but unlikely to cause harm\n- **Class 2 (Significant Impact)**: Critical errors that could lead to misdiagnosis, wrong treatment, or patient safety risks\n\n**CRITICAL: Use BALANCED SENSITIVITY - Don't be overly lenient, but reserve Class 2 for truly critical errors**\n\n**Classification Guidelines:**\n\n**Class 0 - No clinical impact (for truly cosmetic changes):**\n- Punctuation and capitalization changes (periods, commas, capitalization)\n- Filler words (um, uh, like) being added or removed\n- Minor grammatical variations that preserve exact meaning\n- Patient names or identifying information (these don't affect clinical decisions)\n- Incomplete sentences or missing fragments that don't contain medical information\n- Nonsensical transcriptions of unclear speech that don't relate to medical content\n- Missing or altered non-medical conversational content (activities, TV shows, casual remarks)\n- Minor word reordering that preserves identical clinical meaning (e.g., \"one time, one thing at a time\" vs \"one time thing at the time\")\n\n**Class 1 - Moderate clinical relevance (use when clinical context is somewhat affected):**\n- Minor changes to symptom descriptions that don't fundamentally alter clinical understanding\n- Missing contextual details that provide clinical nuance (e.g., \"than usual\" being omitted from work stress descriptions)\n- Timeline discrepancies that could cause mild confusion but don't affect treatment decisions\n- Non-critical medical history details being altered that don't impact current care\n- Minor medication name variations that are still recognizable to clinicians\n- Partial loss of clinical context that doesn't affect diagnosis but may impact clinical understanding\n\n**Class 2 - ONLY for content critically affecting diagnosis/treatment:**\n- Medication status completely misrepresented (e.g., \"don't take medicine\" vs \"need to get medicine\")\n- Specific medication names being completely changed to different medications\n- Critical contraception methods being misidentified with different clinical implications\n- Dosages, frequencies, or administration instructions being significantly altered\n- Symptom severity being dramatically changed (mild to severe or vice versa)\n- Allergies or contraindications being missed, added, or altered\n- Critical family history being lost or fundamentally changed\n- Treatment plans or medical advice being misrepresented\n\n**Decision Framework:**\n1. First ask: \"Does this difference involve actual medical content (symptoms, medications, treatments, medical history, clinical context)?\"\n - If NO \u2192 Class 0\n2. 
If YES, ask: \"Is any clinically relevant information changed, missing, or added, even if not critical?\"\n - If NO change to clinical information \u2192 Class 0\n - If minor clinical context affected but diagnosis/treatment unlikely to change \u2192 Class 1\n - If major impact on diagnosis/treatment decisions \u2192 Class 2\n\n**Key Sensitivity Points:**\n- Be more sensitive to missing clinical context, even if not diagnosis-critical\n- Words like \"usual,\" temporal qualifiers, and descriptive modifiers can carry clinical significance\n- Don't be overly lenient with Class 0 - use Class 1 when clinical nuance is affected\n- Reserve Class 2 only for errors that could directly lead to different clinical decisions\n\n**Key Examples:**\n- Missing \"than usual\" in work stress context = Class 1 (clinical context affected)\n- \"One time, one thing at a time\" vs \"one time thing at the time\" = Class 0 (identical meaning)\n- Complete misrepresentation of medication status = Class 2 (critical error)\n\nProvide your assessment in two parts:\n1. **reasoning**: Explain the key differences and their clinical significance\n2. **clinical_impact**: Assign a number (0, 1, or 2) based on the classification above",
"fields": [
{
"prefix": "Ground Truth Conversation:",
"description": "${ground_truth_conversation}"
},
{
"prefix": "Transcription Conversation:",
"description": "${transcription_conversation}"
},
{
"prefix": "Reasoning:",
"description": "Brief clinical justification for the assessment."
},
{
"prefix": "Clinical Impact:",
"description": "Clinical impact class (return ONLY the number):\n 0 = No impact: cosmetic differences only (punctuation, capitalization, filler words)\n 1 = Minimal impact: some information missing/changed but NOT critical to diagnosis or treatment decisions \n 2 = Significant impact: missing/incorrect information that COULD affect diagnosis, treatment, or patient safety\n Return ONLY: 0, 1, or 2"
}
]
},
"lm": {
"model": "openrouter/google/gemini-2.5-pro",
"model_type": "chat",
"cache": true,
"num_retries": 3,
"finetuning_model": null,
"launch_kwargs": {},
"train_kwargs": {},
"temperature": 0.1,
"max_tokens": 8000
}
},
"metadata": {
"dependency_versions": {
"python": "3.13",
"dspy": "3.0.4",
"cloudpickle": "3.1"
}
}
}

auto_classes.json Normal file

@@ -0,0 +1,4 @@
{
"AutoConfig": "src.llm_judge.signatures.ClinicalImpactJudgeConfig",
"AutoAgent": "src.llm_judge.signatures.ClinicalImpactJudge"
}

config.json Normal file

@@ -0,0 +1,10 @@
{
"task_model": "openrouter/google/gemini-2.5-pro",
"reflection_model": "openrouter/anthropic/claude-sonnet-4",
"max_tokens": 8000,
"temperature": 0.1,
"test_size": 50,
"val_size": 30,
"seed": 42,
"auto": "medium"
}

pyproject.toml Normal file

@@ -0,0 +1,12 @@
[project]
name = "clinical-impact-judge-gepa"
version = "0.1.0"
description = "LLM transcript alignment and evaluation toolkit"
readme = "README.md"
requires-python = ">=3.11"
dependencies = ["openai>=1.35.0", "num2words>=0.5.13", "python-dotenv>=1.2.1", "dspy>=3.0.4", "modaic>=0.4.1", "scikit-learn>=1.7.2", "vertexai>=1.71.1"]
[project.optional-dependencies]
plot = ["matplotlib>=3.8"]
dev = ["pytest>=8.0"]

src/llm_judge/__init__.py Normal file

@@ -0,0 +1,10 @@
"""Lightweight helpers for running DSPy judges with GEPA or MIPRO."""
__all__ = [
"signatures",
"metrics",
"data",
"models",
"optimizers",
"eval",
]

src/llm_judge/cli/__init__.py Normal file

@@ -0,0 +1 @@
# CLI entrypoints live here for `python -m llm_judge.cli.run_*`.

src/llm_judge/cli/run_gepa.py Normal file

@@ -0,0 +1,131 @@
import argparse
from dotenv import load_dotenv
from src.llm_judge import data as data_utils
from src.llm_judge import eval as eval_utils
from src.llm_judge import metrics, signatures
from src.llm_judge.optimizers import get_optimizer
from src.llm_judge.providers import setup_models
def parse_args():
parser = argparse.ArgumentParser(
description="Run GEPA optimization for clinical impact judge."
)
parser.add_argument(
"--data-path",
type=str,
default="llm_judge/dataset/primock_data_final_outcomes.csv",
help="CSV file path.",
)
parser.add_argument(
"--provider",
type=str,
default="openrouter",
        choices=["gemini", "bedrock", "openrouter", "ollama_chat"],
)
parser.add_argument(
"--task-model", type=str, default="meta-llama/llama-3.3-70b-instruct"
)
parser.add_argument(
"--reflection-model", type=str, default="anthropic/claude-sonnet-4"
)
parser.add_argument(
"--no-separate-reflection",
action="store_true",
help="Use task model for reflection too.",
)
parser.add_argument("--output", type=str, default="clinical_judge_gepa.json")
parser.add_argument("--test-size", type=int, default=50)
parser.add_argument("--val-size", type=int, default=30)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument(
"--auto", type=str, default="medium", help="GEPA auto level: light|medium|heavy"
)
return parser.parse_args()
def main():
load_dotenv()
args = parse_args()
if not args.data_path:
raise SystemExit("Please provide --data-path or set DATA_PATH.")
separate_reflection = not args.no_separate_reflection
print("=" * 80)
print("DSPy Clinical Impact Judge - GEPA")
print("=" * 80)
# Data
df = data_utils.load_dataset(args.data_path)
trainset, valset, testset = data_utils.build_splits(
df, test_size=args.test_size, val_size=args.val_size, random_state=args.seed
)
print(f"Train: {len(trainset)} | Val: {len(valset)} | Test: {len(testset)}")
# Models
reflection_model = args.reflection_model if separate_reflection else None
task_lm, reflection_lm = setup_models(
args.provider, task_model=args.task_model, reflection_model=reflection_model
)
print(f"Connected to provider={args.provider} task_model={args.task_model}")
if reflection_lm:
print(f"Using separate reflection model: {args.reflection_model}")
model_base = (
f"{args.provider}/"
if args.provider == "openrouter"
or args.provider == "bedrock"
or args.provider == "ollama_chat"
else "vertex_ai/"
)
config = signatures.ClinicalImpactJudgeConfig(
task_model=model_base + args.task_model,
reflection_model=model_base + args.reflection_model,
test_size=args.test_size,
val_size=args.val_size,
seed=args.seed,
auto=args.auto,
max_tokens=8000,
)
judge = signatures.ClinicalImpactJudge(config)
judge.push_to_hub(
"jaredjoss123/clinical-impact-judge",
with_code=True,
commit_message="Unoptimized Clinical Impact Judge",
)
# Optimizer
optimizer = get_optimizer(
"gepa",
metric=metrics.gepa_feedback_metric,
reflection_lm=reflection_lm,
auto=args.auto,
reflection_minibatch_size=3,
candidate_selection_strategy="pareto",
skip_perfect_score=True,
track_stats=True,
seed=args.seed,
)
optimized_judge = optimizer.compile(
judge,
trainset=trainset,
valset=valset,
)
optimized_judge.save(args.output)
optimized_judge.push_to_hub(
"jaredjoss123/clinical-impact-judge-gepa",
with_code=True,
commit_message="GEPA Optimized on Clinical Impact Judge",
)
print(f"Saved optimized judge to {args.output}")
# Evaluate on test
eval_utils.evaluate_judge(optimized_judge, testset, name="GEPA Optimized")
if __name__ == "__main__":
main()

src/llm_judge/data.py Normal file

@@ -0,0 +1,51 @@
from typing import List, Tuple
import dspy
import pandas as pd
from sklearn.model_selection import train_test_split
def load_dataset(csv_path: str) -> pd.DataFrame:
"""Load dataset and drop rows without final_outcome."""
df = pd.read_csv(csv_path)
return df.dropna(subset=["final_outcome"])
def split_dataset(
df: pd.DataFrame,
test_size: int = 50,
val_size: int = 30,
random_state: int = 42,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""Split into train/val/test with fixed sizes and stratification."""
y = df["final_outcome"]
train_val_df, test_df = train_test_split(
df, test_size=test_size, stratify=y, random_state=random_state
)
train_df, val_df = train_test_split(
train_val_df,
test_size=val_size,
stratify=train_val_df["final_outcome"],
random_state=random_state,
)
return train_df, val_df, test_df
def create_dspy_example(row) -> dspy.Example:
"""Convert dataframe row to DSPy Example."""
return dspy.Example(
ground_truth_conversation=str(row["fer_gt_context"]),
transcription_conversation=str(row["fer_hyp_context"]),
clinical_impact=str(int(row["final_outcome"])),
).with_inputs("ground_truth_conversation", "transcription_conversation")
def build_splits(
df: pd.DataFrame, test_size: int = 50, val_size: int = 30, random_state: int = 42
) -> Tuple[List[dspy.Example], List[dspy.Example], List[dspy.Example]]:
"""Create DSPy train/val/test example lists from dataframe splits."""
train_df, val_df, test_df = split_dataset(df, test_size, val_size, random_state)
trainset = [create_dspy_example(row) for _, row in train_df.iterrows()]
valset = [create_dspy_example(row) for _, row in val_df.iterrows()]
testset = [create_dspy_example(row) for _, row in test_df.iterrows()]
return trainset, valset, testset

src/llm_judge/eval.py Normal file

@@ -0,0 +1,62 @@
import pandas as pd
from sklearn.metrics import classification_report, cohen_kappa_score, confusion_matrix
from .metrics import parse_label
def evaluate_judge(judge, testset, name="Judge"):
"""Evaluate a judge on a testset and print metrics."""
print("\n" + "=" * 80)
print(f"EVALUATING: {name}")
print("=" * 80)
results = []
for idx, example in enumerate(testset):
try:
prediction = judge(
ground_truth_conversation=example.ground_truth_conversation,
transcription_conversation=example.transcription_conversation,
)
pred_label = parse_label(prediction.clinical_impact)
true_label = int(example.clinical_impact)
if pred_label is not None:
results.append({"true_label": true_label, "pred_label": pred_label})
except Exception as exc: # pragma: no cover - runtime guardrail
print(f"Error on example {idx}: {exc}")
continue
if not results:
return None, 0, 0
results_df = pd.DataFrame(results)
true_labels = results_df["true_label"].values
pred_labels = results_df["pred_label"].values
accuracy = (true_labels == pred_labels).mean() * 100
kappa = cohen_kappa_score(true_labels, pred_labels)
print(f"\nAccuracy: {accuracy:.2f}%")
print(f"Cohen's Kappa: {kappa:.3f}")
print("\nClassification Report:")
print(
classification_report(
true_labels,
            pred_labels,
            labels=[0, 1, 2],
            target_names=["0 (No impact)", "1 (Minimal)", "2 (Significant)"],
zero_division=0,
)
)
print("\nConfusion Matrix:")
    cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1, 2])
print(" Predicted")
print(" 0 1 2")
for i, row_label in enumerate(["Actual 0", "Actual 1", "Actual 2"]):
print(f"{row_label:10s} {cm[i][0]:4d} {cm[i][1]:4d} {cm[i][2]:4d}")
for class_label in [0, 1, 2]:
mask = true_labels == class_label
if mask.sum() > 0:
recall = (true_labels[mask] == pred_labels[mask]).mean() * 100
print(f"Class {class_label} recall: {recall:.1f}%")
return results_df, accuracy, kappa

src/llm_judge/metrics.py Normal file

@@ -0,0 +1,108 @@
import json
import re
from typing import Optional
import dspy
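# Score matrix indexed as COST_MATRIX[true_label][pred_label]. Correct predictions (the
# diagonal) score highest; confusing "No impact" with "Significant impact" in either
# direction is the only negatively scored outcome, reflecting the asymmetric clinical risk.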
COST_MATRIX = [
[1.2, 0.3, -1.0],
[0.3, 1.5, 0.5],
[-1.2, 0.4, 1.5],
]
def parse_label(label_str: str) -> Optional[int]:
try:
label_str = str(label_str).strip()
if label_str in {"0", "1", "2"}:
return int(label_str)
json_match = re.search(r"\{.*\}", label_str, re.DOTALL)
if json_match:
obj = json.loads(json_match.group(0))
val = obj.get("clinical_impact")
            if val in (0, 1, 2) or str(val) in {"0", "1", "2"}:
return int(val)
num_match = re.search(r"\b([0-2])\b", label_str)
if num_match:
return int(num_match.group(1))
except Exception:
return None
return None
def gepa_feedback_metric(
example, prediction, trace=None, pred_name=None, pred_trace=None
):
true_label = int(example.clinical_impact)
pred_label = parse_label(prediction.clinical_impact)
if pred_label is None:
feedback = (
f"PARSING ERROR: The model failed to output a valid class (0, 1, or 2). "
f"Raw output: '{prediction.clinical_impact}'. "
f"The model MUST return ONLY the number 0, 1, or 2 as specified in the output field description. "
f"Consider emphasizing in the instructions: output format must be strictly a single digit."
)
return dspy.Prediction(score=-2.0, feedback=feedback)
score = COST_MATRIX[true_label][pred_label]
# Generate detailed feedback based on the prediction outcome
if pred_label == true_label:
class_names = {0: "No impact", 1: "Minimal impact", 2: "Significant impact"}
feedback = (
f"CORRECT: Correctly identified as Class {true_label} ({class_names[true_label]}). "
f"The model's reasoning was appropriate for this classification. "
f"Continue using similar reasoning patterns for this type of case."
)
else:
if true_label == 0 and pred_label > 0:
feedback = (
f"OVER-CLASSIFICATION: Predicted Class {pred_label} but should be Class 0 (No impact). "
f"The transcription differences are cosmetic (punctuation, capitalization, filler words) "
f"and do NOT affect clinical meaning. The model should be MORE LENIENT with minor differences "
f"and focus ONLY on content that affects diagnosis or treatment decisions."
)
elif true_label == 1 and pred_label == 0:
feedback = (
f"UNDER-CLASSIFICATION: Predicted Class 0 but should be Class 1 (Minimal impact). "
f"While not critical to diagnosis/treatment, some clinically relevant information was "
f"missing or changed. The model should be MORE SENSITIVE to information changes, "
f"even if they don't directly affect critical decisions."
)
elif true_label == 1 and pred_label == 2:
feedback = (
f"OVER-CLASSIFICATION: Predicted Class 2 but should be Class 1 (Minimal impact). "
f"The information changes are not critical enough to affect diagnosis or patient safety. "
f"Reserve Class 2 ONLY for errors that COULD directly affect diagnosis, treatment, or safety. "
f"The model should distinguish between 'some information missing' vs 'critical information missing'."
)
elif true_label == 2 and pred_label < 2:
feedback = (
f"CRITICAL MISS: Predicted Class {pred_label} but should be Class 2 (Significant impact). "
f"This is a HIGH-PRIORITY error. The transcription contained missing/incorrect information "
f"that COULD affect diagnosis, treatment, or patient safety. The model MUST be MORE SENSITIVE "
f"to clinically critical information like symptoms, medications, measurements, or diagnoses. "
f"Look for: changed medical terms, missing symptoms, altered measurements, or omitted diagnoses."
)
else:
feedback = (
f"MAJOR ERROR: Predicted Class {pred_label} but should be Class {true_label}. "
f"This is a large classification error spanning 2 severity levels. "
f"The model needs to fundamentally reassess its criteria for clinical impact. "
f"Review the distinction between cosmetic changes, information changes, and critical errors."
)
feedback += f" [True: {true_label}, Predicted: {pred_label}]"
return dspy.Prediction(score=score, feedback=feedback)
def simple_metric(example, prediction, trace=None):
true_label = int(example.clinical_impact)
pred_label = parse_label(prediction.clinical_impact)
if pred_label is None:
return -2.0
return COST_MATRIX[true_label][pred_label]

src/llm_judge/optimizers/__init__.py Normal file

@@ -0,0 +1,3 @@
from .factory import get_optimizer
__all__ = ["get_optimizer"]

src/llm_judge/optimizers/factory.py Normal file

@@ -0,0 +1,14 @@
from typing import Any, Optional
from .gepa import build_gepa
from .mipro import build_mipro
def get_optimizer(name: str, metric, reflection_lm: Optional[Any] = None, **kwargs):
"""Return a configured optimizer by name."""
name = name.lower()
if name == "gepa":
return build_gepa(metric=metric, reflection_lm=reflection_lm, **kwargs)
if name in {"mipro", "miprov2"}:
return build_mipro(metric=metric, **kwargs)
raise ValueError(f"Unsupported optimizer: {name}")

src/llm_judge/optimizers/gepa.py Normal file

@@ -0,0 +1,8 @@
from dspy.teleprompt import GEPA
def build_gepa(metric, reflection_lm=None, **kwargs):
"""Construct a GEPA optimizer."""
if reflection_lm is not None:
kwargs["reflection_lm"] = reflection_lm
return GEPA(metric=metric, **kwargs)

src/llm_judge/optimizers/mipro.py Normal file

@@ -0,0 +1,7 @@
from dspy.teleprompt import MIPROv2
def build_mipro(metric, **kwargs):
"""Construct a MIPROv2 optimizer."""
allowed = {k: v for k, v in kwargs.items() if k in {"auto", "seed"}}
return MIPROv2(metric=metric, **allowed)

src/llm_judge/providers/__init__.py Normal file

@@ -0,0 +1,3 @@
from .factory import setup_models
__all__ = ["setup_models"]

src/llm_judge/providers/bedrock.py Normal file

@@ -0,0 +1,28 @@
from typing import Optional, Tuple
import os
import dspy
from dotenv import load_dotenv
load_dotenv()
def init_bedrock(
task_model: str,
reflection_model: Optional[str] = None,
region: str = "us-east-1",
max_tokens: int = 1000,
) -> Tuple[dspy.LM, Optional[dspy.LM]]:
"""Configure DSPy to use AWS Bedrock."""
    region = os.getenv("AWS_REGION", region)
    task_lm = dspy.LM(
        f"bedrock/{task_model}",
        region_name=region,
        max_tokens=max_tokens,
    )
dspy.settings.configure(lm=task_lm)
reflection_lm = None
if reflection_model:
reflection_lm = dspy.LM(
f"bedrock/{reflection_model}", region_name=region, max_tokens=max_tokens
)
return task_lm, reflection_lm

src/llm_judge/providers/factory.py Normal file

@@ -0,0 +1,25 @@
from typing import Optional
from .bedrock import init_bedrock
from .gemini import init_gemini
from .openrouter import init_openrouter
from .ollama_chat import init_ollama_chat
def setup_models(
provider: str,
task_model: str,
reflection_model: Optional[str] = None,
**kwargs,
):
"""Initialize task/reflection LMs and configure DSPy."""
provider = provider.lower()
if provider == "gemini":
return init_gemini(task_model, reflection_model=reflection_model, **kwargs)
if provider == "bedrock":
return init_bedrock(task_model, reflection_model=reflection_model, **kwargs)
if provider == "openrouter":
return init_openrouter(task_model, reflection_model=reflection_model, **kwargs)
if provider == "ollam_chat":
return init_ollama_chat(task_model, reflection_model=reflection_model, **kwargs)
raise ValueError(f"Unsupported provider: {provider}")

src/llm_judge/providers/gemini.py Normal file

@@ -0,0 +1,24 @@
import os
from typing import Optional, Tuple
import dspy
import vertexai
def init_gemini(
task_model: str,
reflection_model: Optional[str] = None,
max_tokens: int = 8000,
) -> Tuple[dspy.LM, Optional[dspy.LM]]:
"""Configure DSPy to use Gemini via Vertex AI."""
project = os.getenv("GCP_PROJECT_ID", "your-project-id")
location = os.getenv("GCP_LOCATION", "us-central1")
vertexai.init(project=project, location=location)
task_lm = dspy.LM(f"vertex_ai/{task_model}", max_tokens=max_tokens)
dspy.settings.configure(lm=task_lm)
reflection_lm = None
if reflection_model:
reflection_lm = dspy.LM(f"vertex_ai/{reflection_model}", max_tokens=max_tokens)
return task_lm, reflection_lm

src/llm_judge/providers/ollama_chat.py Normal file

@@ -0,0 +1,28 @@
from typing import Optional, Tuple
import dspy
def init_ollama_chat(
task_model: str,
reflection_model: Optional[str] = None,
max_tokens: int = 1000,
) -> Tuple[dspy.LM, Optional[dspy.LM]]:
"""Configure DSPy to use Ollama Chat."""
task_lm = dspy.LM(
f"ollama_chat/{task_model}",
api_base="http://localhost:11434",
api_key="",
max_tokens=max_tokens,
)
dspy.settings.configure(lm=task_lm)
reflection_lm = None
if reflection_model:
reflection_lm = dspy.LM(
f"ollama_chat/{reflection_model}",
api_base="http://localhost:11434",
api_key="",
max_tokens=max_tokens,
)
return task_lm, reflection_lm

src/llm_judge/providers/openrouter.py Normal file

@@ -0,0 +1,38 @@
import os
from typing import Optional, Tuple
import dspy
def init_openrouter(
task_model: str,
reflection_model: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
max_tokens: int = 4000,
) -> Tuple[dspy.LM, Optional[dspy.LM]]:
"""Configure DSPy to use OpenRouter."""
base_url = base_url or os.getenv(
"OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"
)
api_key = api_key or os.getenv("OPENROUTER_API_KEY")
if not api_key:
raise ValueError("OPENROUTER_API_KEY is required for OpenRouter provider")
task_lm = dspy.LM(
f"openrouter/{task_model}",
api_base=base_url,
api_key=api_key,
max_tokens=max_tokens,
)
dspy.settings.configure(lm=task_lm)
reflection_lm = None
if reflection_model:
reflection_lm = dspy.LM(
f"openrouter/{reflection_model}",
api_base=base_url,
api_key=api_key,
max_tokens=max_tokens,
)
return task_lm, reflection_lm

src/llm_judge/signatures.py Normal file

@@ -0,0 +1,58 @@
import dspy
from modaic import PrecompiledAgent, PrecompiledConfig
from typing import Optional
class ClinicalImpactAssessment(dspy.Signature):
"""Assess the clinical impact of transcription errors in medical conversations.
Compare the ground truth conversation with the transcription conversation and determine
if errors would affect patient care. Focus on THREE distinct severity levels.
"""
ground_truth_conversation = dspy.InputField()
transcription_conversation = dspy.InputField()
reasoning = dspy.OutputField(
desc="Brief clinical justification for the assessment."
)
clinical_impact = dspy.OutputField(
desc="""Clinical impact class (return ONLY the number):
0 = No impact: cosmetic differences only (punctuation, capitalization, filler words)
1 = Minimal impact: some information missing/changed but NOT critical to diagnosis or treatment decisions
2 = Significant impact: missing/incorrect information that COULD affect diagnosis, treatment, or patient safety
Return ONLY: 0, 1, or 2"""
)
class ClinicalImpactJudgeConfig(PrecompiledConfig):
task_model: str = "gemini-2.5-pro"
reflection_model: Optional[str] = None # for GEPA runs
max_tokens: int = 8000
temperature: float = 0.1
test_size: Optional[int] = 50
val_size: Optional[int] = 30
seed: Optional[int] = 42
auto: Optional[str] = "medium"
class ClinicalImpactJudge(PrecompiledAgent):
"""LLM Judge for assessing clinical impact."""
config: ClinicalImpactJudgeConfig
def __init__(self, config: ClinicalImpactJudgeConfig, **kwargs):
super().__init__(config, **kwargs)
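        # ChainOfThought adds an intermediate reasoning step ahead of the signature's output fields;
        # giving the predictor its own LM keeps the judge independent of the global dspy.settings default.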
self.assess = dspy.ChainOfThought(ClinicalImpactAssessment)
self.assess.set_lm(
dspy.LM(
config.task_model,
max_tokens=config.max_tokens,
temperature=config.temperature,
)
)
def forward(self, ground_truth_conversation, transcription_conversation):
return self.assess(
ground_truth_conversation=ground_truth_conversation,
transcription_conversation=transcription_conversation,
)