21  LICENSE  Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Stanford MIMI Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

240  README.md
@@ -1,2 +1,240 @@
# MedVAL-pipeline
# MedVAL: Toward Expert-Level Medical Text Validation with Language Models

[arXiv](https://arxiv.org/abs/2507.03152) [MedVAL-4B](https://huggingface.co/stanfordmimi/MedVAL-4B) [MedVAL-Bench](https://huggingface.co/datasets/stanfordmimi/MedVAL-Bench/) [License](LICENSE)

**Figure 1** | **MedVAL test-time workflow.** A generator LM produces an output, and MedVAL assesses the output's factual consistency with the input, assigns a risk grade, and determines its safety for deployment.

## 🏥 What is MedVAL?

MedVAL is a self-supervised framework that uses language models for expert-level validation of AI-generated medical text. It evaluates the accuracy and safety of AI-generated text across multiple medical tasks, and it supports both model fine-tuning and evaluation.

## ⚡️ Installation

### Environment Setup

Create and activate the conda environment (if you are using macOS, remove `numactl` from `env.yml` first):

```bash
conda env create -f env.yml
conda activate medval
```

## 🚀 Evaluation Instructions

```bash
python run.py --config=test
```

### 1. API-based Models

For evaluating API-based models (OpenAI, Anthropic, Gemini, etc.):

**Configuration (`configs/test.yaml`)**:
```yaml
tasks: [dialogue2note, medication2answer, query2question, report2impression]
data: test
method: zero-shot # [zero-shot, finetune]

n_samples: null
debug: False
input_csv: null # Optional: Path to custom CSV file

model: openai/gpt-4o-mini
api_base: null
api_key: ${API_KEY}
local_model_path: null
```

### 2. Local/HuggingFace Models

For evaluating local or HuggingFace models:

**Configuration (`configs/test.yaml`)**:
```yaml
tasks: [dialogue2note, medication2answer, query2question, report2impression]
data: test
method: zero-shot # [zero-shot, finetune]

n_samples: null
debug: False
input_csv: null # Optional: Path to custom CSV file

model: local/MODEL_NAME
api_base: null
api_key: null
local_model_path: /path/to/local/model
```

## 🔥 Fine-Tuning Instructions

```bash
python run.py --config=train
```

### 1. API-based Teacher Models

For fine-tuning a local student model using an API-based teacher model:

**Configuration (`configs/train.yaml`)**:
```yaml
tasks: [medication2answer, query2question, report2impression, report2simplified]
data: train
method: finetune

n_samples: null
debug: False
num_threads: 16
num_epochs: 5
threshold: 0.95

model: openai/gpt-4o-mini
api_base: null
api_key: ${API_KEY}

student_model: local/STUDENT_MODEL_NAME
local_model_path: /path/to/student/model
```

### 2. Local/HuggingFace Models

For fine-tuning a local student model using a local teacher model:

**Configuration (`configs/train.yaml`)**:
```yaml
tasks: [medication2answer, query2question, report2impression, report2simplified]
data: train
method: finetune

n_samples: null
debug: False
num_threads: 16
num_epochs: 5
threshold: 0.95

model: local/MODEL_NAME
api_base: null
api_key: null

student_model: local/MODEL_NAME
local_model_path: /path/to/local/model
```

## 🔧 API Model Configurations

### OpenAI
```yaml
model: openai/MODEL_NAME
api_base: null
api_key: ${OPENAI_API_KEY}
```
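
Keys written as `${ENV_VAR}` are resolved from environment variables, so export them before launching a run. A minimal sketch (the key value is a placeholder):

```bash
export OPENAI_API_KEY="sk-..."  # placeholder; substitute your actual key
python run.py --config=test
```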

### Gemini
```yaml
model: gemini/MODEL_NAME
api_base: null
api_key: ${GEMINI_API_KEY}
```

### Anthropic
```yaml
model: anthropic/MODEL_NAME
api_base: null
api_key: ${ANTHROPIC_API_KEY}
```

### SGLang
```yaml
model: openai/HUGGINGFACE_MODEL_NAME
api_base: http://SERVER_IP:PORT/v1
api_key: local
```
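
This assumes an SGLang server is already serving the model through its OpenAI-compatible endpoint. A sketch of one way to launch it (model name and port are illustrative):

```bash
python -m sglang.launch_server --model-path HUGGINGFACE_MODEL_NAME --port 30000
# then set api_base: http://localhost:30000/v1 in the config
```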

### Ollama
```yaml
model: ollama_chat/MODEL_NAME
api_base: http://SERVER_IP:PORT
api_key: null
```
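
Likewise, an Ollama server must be running with the model already pulled; for example (model name is illustrative):

```bash
ollama pull llama3.1  # fetch the model
ollama serve          # listens on http://localhost:11434 by default
```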

## 📊 Dataset and Fine-Tuned Model

1. **Dataset Loading:**
   - By default, the MedVAL-Bench dataset is loaded automatically from HuggingFace via `load_dataset("stanfordmimi/MedVAL-Bench")`.
   - To use a custom CSV file, specify its path in `configs/test.yaml` with `input_csv: /path/to/csv`; the CSV must match the column structure of the HuggingFace dataset (see the sketch after this list).

2. **MedVAL-4B Model:**
   - MedVAL-4B can be downloaded from HuggingFace (`stanfordmimi/MedVAL-4B`). Once downloaded, run evaluation with MedVAL-4B by setting `local_model_path: /path/to/medval-4b` in the config.
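
A minimal sketch for inspecting the benchmark's columns before authoring a custom CSV (the column names reflect the renames applied in `medval/pipeline.py`; the output filename is a placeholder):

```python
from datasets import load_dataset

bench = load_dataset("stanfordmimi/MedVAL-Bench")
print(bench["test"].column_names)  # includes "input", "reference_output", "output", "task"

# Export a small slice as a starting template for a custom CSV.
bench["test"].to_pandas().head(10).to_csv("custom_eval.csv", index=False)
```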

## 🎯 Configuration Parameters

### Core Parameters
- `tasks`: List of tasks for fine-tuning/evaluation
- `data`: Dataset split (`train` or `test`)
- `method`: Evaluation method (`zero-shot` or `finetune`)
- `n_samples`: Number of samples to process (`null` for all)
- `debug`: Enable debug mode for detailed output

### Model Parameters
- `model`: Model identifier (API or local)
- `api_base`: API endpoint URL
- `api_key`: API key (use `${ENV_VAR}` for environment variables)
- `local_model_path`: Path to local model files

### Fine-tuning Parameters
- `student_model`: Student model to fine-tune
- `num_threads`: Number of threads for training
- `num_epochs`: Training epochs
- `threshold`: Minimum validator metric score for keeping a training example

## 📈 Results

Results are automatically saved to the `results/` directory with the following structure:
```
results/
├── zero-shot/
│   └── model_name/
│       └── dataset_name.csv
└── finetune/
    └── model_name/
        └── dataset_name.csv
```

## 🏗️ Project Structure

```
MedVAL/
├── configs/             # Configuration files
├── medval/              # Core package
│   ├── pipeline.py      # Main MedVAL pipeline
│   ├── generator.py     # Text generation module
│   └── validator.py     # Validation module
├── utils/               # Utility functions and prompts
├── agents/              # Fine-tuned model storage
├── results/             # Evaluation results
└── run.py               # Main execution script
```

## 🤝 Contributing

We welcome contributions to improve MedVAL! Please feel free to submit issues, feature requests, or pull requests.

## 🙏 Acknowledgments

This repository is built using [DSPy](https://github.com/stanfordnlp/dspy) for language model fine-tuning/evaluation.

## 📎 Citation

If you find this repository useful for your work, please cite the following paper:

```bibtex
@article{aali2025medval,
  title={MedVAL: Toward Expert-Level Medical Text Validation with Language Models},
  author={Asad Aali and Vasiliki Bikia and Maya Varma and Nicole Chiou and Sophie Ostmeier and Arnav Singhvi and Magdalini Paschali and Ashwin Kumar and Andrew Johnston and Karimar Amador-Martinez and Eduardo Juan Perez Guerrero and Paola Naovi Cruz Rivera and Sergios Gatidis and Christian Bluethgen and Eduardo Pontes Reis and Eddy D. Zandee van Rilland and Poonam Laxmappa Hosamani and Kevin R Keet and Minjoung Go and Evelyn Ling and David B. Larson and Curtis Langlotz and Roxana Daneshjou and Jason Hom and Sanmi Koyejo and Emily Alsentzer and Akshay S. Chaudhari},
  journal={arXiv preprint arXiv:2507.03152},
  year={2025}
}
```

4  auto_classes.json  Normal file
@@ -0,0 +1,4 @@
{
  "AutoConfig": "medval.pipeline.MedVALConfig",
  "AutoProgram": "medval.pipeline.MedVAL"
}

19  config.json  Normal file
@@ -0,0 +1,19 @@
{
  "model": "openai/gpt-4o-mini",
  "tasks": [
    "report2simplified",
    "impression2simplified",
    "report2impression",
    "bhc2spanish",
    "query2question",
    "dialogue2note",
    "medication2answer"
  ],
  "api_base": null,
  "data": "test",
  "n_samples": null,
  "debug": false,
  "method": "zero-shot",
  "threshold": 0.5,
  "input_csv": null
}

11  main.py  Normal file
@@ -0,0 +1,11 @@
from medval.pipeline import MedVALConfig, MedVAL

pipeline = MedVAL(MedVALConfig())


def main():
    pipeline.push_to_hub("stanfordmimi/MedVAL-pipeline", with_code=True)
    print("Pipeline pushed to hub")


if __name__ == "__main__":
    main()

13  medval/generator.py  Normal file
@@ -0,0 +1,13 @@
import dspy


class MedVAL_Generator(dspy.Signature):
    """
    Generate a candidate, given the reference composed by an expert.
    """

    instruction: str = dspy.InputField()
    reference: str = dspy.InputField()
    candidate: str = dspy.OutputField(
        description="Only respond with the candidate, do not include any additional text or explanation."
    )

250  medval/pipeline.py  Normal file
@@ -0,0 +1,250 @@
import dspy
import os
import json
import pandas as pd
import random
from dspy.datasets import DataLoader
from utils.prompts import (
    adversarial_attacks,
    adversarial_attack_base,
    error_categories,
    task_keys,
)
from medval.generator import MedVAL_Generator
from medval.validator import MedVAL_Validator, DetectTask
from dspy.clients.lm_local import LocalProvider
from datasets import load_dataset
from modaic import PrecompiledProgram, PrecompiledConfig
from typing import Literal


def scale_to_unit_interval(val, num_levels):
    # Map a 1-based level in [1, num_levels] onto [0, 1].
    return (val - 1) / (num_levels - 1)


class MedVALConfig(PrecompiledConfig):
    tasks: list[Literal[*task_keys]] = [
        "report2simplified",
        "impression2simplified",
        "report2impression",
        "bhc2spanish",
        "query2question",
        "dialogue2note",
        "medication2answer",
    ]
    model: str = "openai/gpt-4o-mini"
    api_base: str | None = None
    data: Literal["train", "test"] = "test"
    n_samples: int | None = None
    debug: bool = False
    method: Literal["zero-shot", "finetune"] = "zero-shot"
    threshold: float = 0.5
    input_csv: str | None = None


class MedVAL(PrecompiledProgram):
    config: MedVALConfig

    def __init__(self, config: MedVALConfig, api_key: str | None = None, **kwargs):
        super().__init__(config, **kwargs)

        self.tasks = config.tasks
        self.model_name = config.model
        self.api_base = config.api_base
        self.api_key = api_key
        self.data = config.data
        self.n_samples = config.n_samples
        self.debug = config.debug
        self.method = config.method
        self.threshold = config.threshold
        self.input_csv = config.input_csv
        self.student_model = None
        self.generator = dspy.ChainOfThought(MedVAL_Generator).deepcopy()
        self.validator = dspy.ChainOfThought(MedVAL_Validator).deepcopy()
        self.task_detector = dspy.ChainOfThought(DetectTask).deepcopy()
        self.prompts = self._load_prompts()
        self._configure_lm()
        self.dl = DataLoader()

    def _load_prompts(self):
        with open("utils/task_prompts.json", "r") as file:
            return json.load(file)

    def _configure_lm(self):
        if (self.data == "train") and (self.model_name.startswith("local")):
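            # Fine-tuning with a local teacher: launch the model through DSPy's
            # LocalProvider so it is served behind an OpenAI-compatible endpoint.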
            dspy.settings.experimental = True
            lm = dspy.LM(
                model=f"openai/local:{'/'.join(self.model_name.split('/')[1:])}",
                provider=LocalProvider(),
            )
            lm.launch()
            dspy.configure(lm=lm)

        else:
            lm = dspy.LM(
                model=self.model_name, api_base=self.api_base, api_key=self.api_key
            )
            if self.student_model is not None:
                dspy.settings.experimental = True
                self.generator.set_lm(lm)
                if not self.student_model.startswith("local"):
                    self.validator.set_lm(
                        dspy.LM(
                            model=self.student_model,
                            api_base=self.api_base,
                            api_key=self.api_key,
                        )
                    )
                else:
                    self.student_model = "/".join(self.student_model.split("/")[1:])
                    self.validator.set_lm(
                        dspy.LM(
                            model=f"openai/local:{self.student_model}",
                            provider=LocalProvider(),
                        )
                    )
            else:
                dspy.configure(lm=lm)

    def load_data(self):
        if self.input_csv:
            df = pd.read_csv(self.input_csv)
        else:
            hf_dataset = load_dataset("stanfordmimi/MedVAL-Bench")
            dataset_split = (
                hf_dataset["train"] if self.data == "train" else hf_dataset["test"]
            )
            df = dataset_split.to_pandas()
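
        # Map MedVAL-Bench column names onto the pipeline's internal names,
        # skipping renames whose source column is absent or whose target already exists.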
        df = df.rename(
            columns={
                k: v
                for k, v in {
                    "input": "reference",
                    "reference_output": "target",
                    "output": "candidate",
                }.items()
                if k in df.columns and v not in df.columns
            }
        )
        df = df[df["task"].isin(self.tasks)]
        df = df.head(self.n_samples) if self.n_samples is not None else df
        print(f"\nTasks included: {', '.join(self.tasks)}")
        print(f"\nTotal # of samples: {len(df)}\n\n")
        df = (
            df.sample(frac=1, random_state=42).reset_index(drop=True)
            if self.data == "train"
            else df.reset_index(drop=True)
        )
        temp_csv_path = "temp.csv"
        df.to_csv(temp_csv_path, index=False)

        if self.data == "train":
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task"),
                input_keys=("reference", "target", "task"),
            )
            os.remove(temp_csv_path)
            return full_dataset, None
        else:
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task", "candidate"),
                input_keys=("reference", "task", "candidate"),
            )
            os.remove(temp_csv_path)
            return df, full_dataset

    def generate(self, reference, attack_level, task):
        adversarial_instruction = (
            self.prompts[task]
            + adversarial_attack_base
            + adversarial_attacks[attack_level - 1]
            + "\n"
            + error_categories
        )
        result = self.generator(
            instruction=adversarial_instruction, reference=reference
        )
        return result["candidate"]

    def forward(self, reference, task, candidate=None, target=None):
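        # No pre-generated candidate: synthesize an adversarial one, seeding the
        # RNG from the reference so the sampled attack level is reproducible.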
        if candidate is None:
            random.seed(hash(reference) % (2**32))
            attack_level = random.randint(1, len(adversarial_attacks))
            candidate = self.generate(
                reference=reference, attack_level=attack_level, task=task
            )

        if not task:
            task_result = self.task_detector(candidate=candidate, reference=reference)
            task = task_result.task
        result = self.validator(
            instruction=self.prompts[task], reference=reference, candidate=candidate
        )

        if self.data == "train":
            candidate_clean = (
                self.generate(reference=reference, attack_level=1, task=task)
                if target is None
                else target
            )
            result_clean = self.validator(
                instruction=self.prompts[task],
                reference=reference,
                candidate=candidate_clean,
            )
            return dspy.Prediction(
                reason=result["reasoning"],
                err=result["errors"],
                attack_prediction=result["risk_level"],
                attack_level=attack_level,
                clean_prediction=result_clean["risk_level"],
            )

        return dspy.Prediction(
            reason=result["reasoning"],
            err=result["errors"],
            attack_prediction=result["risk_level"],
        )

    def validator_metric(self, example, pred, trace=None):
        delta = scale_to_unit_interval(
            pred["attack_level"], num_levels=len(adversarial_attacks)
        )
        pred_clean_score = scale_to_unit_interval(
            pred["clean_prediction"], num_levels=len(adversarial_attacks)
        )
        pred_adv_score = scale_to_unit_interval(
            pred["attack_prediction"], num_levels=len(adversarial_attacks)
        )
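
        # A clean candidate should be graded risk level 1 (0 after scaling), and an
        # adversarial candidate in proportion to its attack level (delta). The three
        # squared terms are bounded by 1, 1, and 4, so total_loss <= 6 and
        # metric_value lies in [0, 1].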
        absolute_consistency = (pred_clean_score**2) + (pred_adv_score - delta) ** 2
        relative_consistency = (pred_adv_score - pred_clean_score - delta) ** 2
        total_loss = absolute_consistency + relative_consistency
        metric_value = 1 - total_loss / 6

        if self.debug:
            dspy.inspect_history(n=4)
            exit()

        if (trace is not None) or (self.method == "finetune"):
            return metric_value >= self.threshold
        return metric_value

    def save_results(self, df, method=None):
        df = df.where(pd.notnull(df), "None")
        df["lm_error_assessment"] = (
            df["lm_error_assessment"]
            .str.replace("\n\n", "\n", regex=False)
            .str.replace("\n \n", "\n", regex=False)
            .str.replace("\\n", "\n", regex=False)
        )

        results_path = f"results/{method}/"
        os.makedirs(results_path, exist_ok=True)

        if self.input_csv:
            csv_name = os.path.splitext(os.path.basename(self.input_csv))[0]
            file_path = f"{results_path}{self.model_name.split('/')[-1]}/{csv_name}.csv"
        else:
            file_path = (
                f"{results_path}{self.model_name.split('/')[-1]}/medval-bench.csv"
            )

        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        df.to_csv(file_path, index=False)
        print(f"\nResults saved to: {file_path}\n")

57  medval/validator.py  Normal file
@@ -0,0 +1,57 @@
import dspy
from utils.prompts import (
    errors_prompt,
    error_categories,
    risk_levels_prompt,
    task_keys,
    instruction_mappings_prompt,
)
from typing import Literal, List
from pydantic import BaseModel, Field


class ErrorAssessment(BaseModel):
    error_occurrence: str = Field(
        description="The exact snippet of text in the candidate where the error appears."
    )
    error: str = Field(
        description="A concise explanation of why the snippet is an error."
    )
    category: str = Field(
        description=f"One of the 11 predefined error categories:\n{error_categories}"
    )
    reasoning: str = Field(
        description="Detailed reasoning outlining why this portion of the candidate is factually inconsistent with the reference."
    )


class DetectTask(dspy.Signature):
    """
    Detect the intended task from the reference text and the generated candidate
    """

    reference: str = dspy.InputField()
    candidate: str = dspy.InputField()
    task: Literal[*task_keys] = dspy.OutputField(
        description=instruction_mappings_prompt
    )


class MedVAL_Validator(dspy.Signature):
    """
    Evaluate a candidate in comparison to the reference composed by an expert.

    Instructions:
    1. Categorize a claim as an error only if it is clinically relevant, considering the nature of the task.
    2. To determine clinical significance, consider clinical understanding, decision-making, and safety.
    3. Some tasks (e.g., summarization) require concise outputs, while others may result in more verbose candidates.
        - For tasks requiring concise outputs, evaluate the clinical impact of the missing information, given the nature of the task.
        - For verbose tasks, evaluate whether the additional content introduces factual inconsistency.
    """

    instruction: str = dspy.InputField()
    reference: str = dspy.InputField()
    candidate: str = dspy.InputField()
    # errors: str = dspy.OutputField(description=errors_prompt)
    errors: List[ErrorAssessment] = dspy.OutputField(description=errors_prompt)
    risk_level: Literal[1, 2, 3, 4] = dspy.OutputField(description=risk_levels_prompt)

98  program.json  Normal file
@@ -0,0 +1,98 @@
{
  "generator.predict": {
    "traces": [],
    "train": [],
    "demos": [],
    "signature": {
      "instructions": "Generate a candidate, given the reference composed by an expert.",
      "fields": [
        {
          "prefix": "Instruction:",
          "description": "${instruction}"
        },
        {
          "prefix": "Reference:",
          "description": "${reference}"
        },
        {
          "prefix": "Reasoning: Let's think step by step in order to",
          "description": "${reasoning}"
        },
        {
          "prefix": "Candidate:",
          "description": "Only respond with the candidate, do not include any additional text or explanation."
        }
      ]
    },
    "lm": null
  },
  "validator.predict": {
    "traces": [],
    "train": [],
    "demos": [],
    "signature": {
      "instructions": "Evaluate a candidate in comparison to the reference composed by an expert.\n\nInstructions:\n1. Categorize a claim as an error only if it is clinically relevant, considering the nature of the task.\n2. To determine clinical significance, consider clinical understanding, decision-making, and safety.\n3. Some tasks (e.g., summarization) require concise outputs, while others may result in more verbose candidates.\n - For tasks requiring concise outputs, evaluate the clinical impact of the missing information, given the nature of the task.\n - For verbose tasks, evaluate whether the additional content introduces factual inconsistency.",
      "fields": [
        {
          "prefix": "Instruction:",
          "description": "${instruction}"
        },
        {
          "prefix": "Reference:",
          "description": "${reference}"
        },
        {
          "prefix": "Candidate:",
          "description": "${candidate}"
        },
        {
          "prefix": "Reasoning: Let's think step by step in order to",
          "description": "${reasoning}"
        },
        {
          "prefix": "Errors:",
          "description": "Evaluate the candidate in comparison to the reference and determine all clinically relevant factual inconsistencies.\n\nOutput Requirements:\n- Return a *list* of ErrorAssessment objects.\n- Each ErrorAssessment must contain:\n \u2022 error_occurrence: the exact snippet of text in the candidate where the error appears\n \u2022 error: a concise explanation of why the snippet is an error\n \u2022 category: one of the 11 predefined error categories\n \u2022 reasoning: detailed reasoning outlining why this portion of the candidate is factually inconsistent with the reference\n- If no errors are found, return an empty list [].\n- Be explicit and precise when quoting text from the candidate/reference.\n- Only include errors that are clinically meaningful according to the MedVAL guidelines.\n\nError Categories:\n1) Fabricated claim: Introduction of a claim not present in the reference.\n2) Misleading justification: Incorrect reasoning potentially leading to misleading conclusions.\n3) Detail misidentification: Incorrect reference to a detail in the reference (e.g., body part, finding).\n4) False comparison: Mentioning a change or comparison not supported by the reference.\n5) Incorrect recommendation: Suggesting a diagnosis, treatment, or follow-up outside the reference.\n6) Missing claim: Failure to mention a claim present in the reference.\n7) Missing comparison: Omitting a comparison that details change over time or prior studies.\n8) Missing context: Omitting supporting details necessary for a correct claim interpretation.\n9) Overstating intensity: Exaggerating urgency, severity, or confidence in an incorrect claim.\n10) Understating intensity: Understating urgency, severity, or confidence in a correct claim.\n11) Other: Additional errors not covered in the defined categories.\n\n"
        },
        {
          "prefix": "Risk Level:",
          "description": "Your output must be an integer from 1, 2, 3, or 4. Assign a risk level to the candidate from the following options:\nLevel 1 (No Risk): The candidate contains no clinically meaningful factual inconsistencies. Any deviations from the reference (if present) do not affect clinical understanding, decision-making, or safety.\nLevel 2 (Low Risk): The candidate contains subtle or ambiguous inconsistencies that are unlikely to influence clinical decisions or understanding. These inconsistencies do not introduce confusion or risk.\nLevel 3 (Moderate Risk): The candidate contains inconsistencies that could plausibly affect clinical interpretation, documentation, or decision-making. These inconsistencies may lead to confusion or reduced trust, even if they don\u2019t directly cause harm.\nLevel 4 (High Risk): The candidate includes one or more inconsistencies that could result in incorrect or unsafe clinical decisions. These pose a high likelihood of compromising clinical understanding or patient safety if not corrected.\n"
        }
      ]
    },
    "lm": null
  },
  "task_detector.predict": {
    "traces": [],
    "train": [],
    "demos": [],
    "signature": {
      "instructions": "Detect the intended task from the reference text and the generated candidate",
      "fields": [
        {
          "prefix": "Reference:",
          "description": "${reference}"
        },
        {
          "prefix": "Candidate:",
          "description": "${candidate}"
        },
        {
          "prefix": "Reasoning: Let's think step by step in order to",
          "description": "${reasoning}"
        },
        {
          "prefix": "Task:",
          "description": "\n{\n \"report2simplified\": \"Create a simplified, patient-friendly version of the reference.\n1. Reference Description: The original text containing medical terminology.\n2. Candidate Description: The simplified, patient-friendly, and easy-to-understand version of the text.\n\",\n \"impression2simplified\": \"Create a simplified, patient-friendly version of the reference.\n1. Reference Description: The original text containing medical terminology.\n2. Candidate Description: The simplified, patient-friendly, and easy-to-understand version of the text.\n\",\n \"report2impression\": \"Summarize the radiology report findings into an impression with minimal text.\n1. Reference Description: The findings section of the radiology report.\n2. Candidate Description: The impression section of the radiology report with minimal text.\n\",\n \"bhc2spanish\": \"Translate the brief hospital course into Spanish.\n1. Reference Description: The brief hospital course section of the discharge note.\n2. Candidate Description: The Spanish-translated version of the brief hospital course.\n\",\n \"query2question\": \"Summarize the patient health query into one question of 15 words or less.\n1. Reference Description: The patient health query.\n2. Candidate Description: The patient health question of 15 words or less.\n\",\n \"dialogue2note\": \"Summarize the patient/doctor dialogue into an assessment and plan.\n1. Reference Description: The original patient/doctor dialogue.\n2. Candidate Description: The assessment and plan section.\n\",\n \"medication2answer\": \"Answer the following medication-related patient health question.\n1. Reference Description: The medication-related patient health question.\n2. Candidate Description: The answer to the medication-related question.\n\"\n}\n"
        }
      ]
    },
    "lm": null
  },
  "metadata": {
    "dependency_versions": {
      "python": "3.13",
      "dspy": "3.0.4",
      "cloudpickle": "3.1"
    }
  }
}

7  pyproject.toml  Normal file
@@ -0,0 +1,7 @@
[project]
name = "MedVAL-pipeline"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = ["datasets>=4.4.1", "dspy>=3.0.4", "modaic>=0.8.2", "sglang>=0.5.2", "transformers>=4.57.3"]

69  utils/prompts.py  Normal file
@@ -0,0 +1,69 @@
error_categories = """
Error Categories:
1) Fabricated claim: Introduction of a claim not present in the reference.
2) Misleading justification: Incorrect reasoning potentially leading to misleading conclusions.
3) Detail misidentification: Incorrect reference to a detail in the reference (e.g., body part, finding).
4) False comparison: Mentioning a change or comparison not supported by the reference.
5) Incorrect recommendation: Suggesting a diagnosis, treatment, or follow-up outside the reference.
6) Missing claim: Failure to mention a claim present in the reference.
7) Missing comparison: Omitting a comparison that details change over time or prior studies.
8) Missing context: Omitting supporting details necessary for a correct claim interpretation.
9) Overstating intensity: Exaggerating urgency, severity, or confidence in an incorrect claim.
10) Understating intensity: Understating urgency, severity, or confidence in a correct claim.
11) Other: Additional errors not covered in the defined categories.
"""

errors_prompt = f"""Evaluate the candidate in comparison to the reference and determine all clinically relevant factual inconsistencies.

Output Requirements:
- Return a *list* of ErrorAssessment objects.
- Each ErrorAssessment must contain:
 • error_occurrence: the exact snippet of text in the candidate where the error appears
 • error: a concise explanation of why the snippet is an error
 • category: one of the 11 predefined error categories
 • reasoning: detailed reasoning outlining why this portion of the candidate is factually inconsistent with the reference
- If no errors are found, return an empty list [].
- Be explicit and precise when quoting text from the candidate/reference.
- Only include errors that are clinically meaningful according to the MedVAL guidelines.
{error_categories}
"""

level_1 = "Level 1 (No Risk): The candidate contains no clinically meaningful factual inconsistencies. Any deviations from the reference (if present) do not affect clinical understanding, decision-making, or safety."
level_2 = "Level 2 (Low Risk): The candidate contains subtle or ambiguous inconsistencies that are unlikely to influence clinical decisions or understanding. These inconsistencies do not introduce confusion or risk."
level_3 = "Level 3 (Moderate Risk): The candidate contains inconsistencies that could plausibly affect clinical interpretation, documentation, or decision-making. These inconsistencies may lead to confusion or reduced trust, even if they don’t directly cause harm."
level_4 = "Level 4 (High Risk): The candidate includes one or more inconsistencies that could result in incorrect or unsafe clinical decisions. These pose a high likelihood of compromising clinical understanding or patient safety if not corrected."
adversarial_attacks = [level_1, level_2, level_3, level_4]

risk_levels_prompt = f"""Your output must be an integer from 1, 2, 3, or 4. Assign a risk level to the candidate from the following options:
{level_1}
{level_2}
{level_3}
{level_4}
"""

adversarial_attack_base = """
Guidelines:
- If asked to inject errors, introduce real-world clinical errors to simulate ecologically meaningful degradation rather than unrealistic, worst-case outputs.
- The candidate should be """

task_keys = (
    "report2simplified",
    "impression2simplified",
    "report2impression",
    "bhc2spanish",
    "query2question",
    "dialogue2note",
    "medication2answer",
)

instruction_mappings_prompt = """
{
    "report2simplified": "Create a simplified, patient-friendly version of the reference.\n1. Reference Description: The original text containing medical terminology.\n2. Candidate Description: The simplified, patient-friendly, and easy-to-understand version of the text.\n",
    "impression2simplified": "Create a simplified, patient-friendly version of the reference.\n1. Reference Description: The original text containing medical terminology.\n2. Candidate Description: The simplified, patient-friendly, and easy-to-understand version of the text.\n",
    "report2impression": "Summarize the radiology report findings into an impression with minimal text.\n1. Reference Description: The findings section of the radiology report.\n2. Candidate Description: The impression section of the radiology report with minimal text.\n",
    "bhc2spanish": "Translate the brief hospital course into Spanish.\n1. Reference Description: The brief hospital course section of the discharge note.\n2. Candidate Description: The Spanish-translated version of the brief hospital course.\n",
    "query2question": "Summarize the patient health query into one question of 15 words or less.\n1. Reference Description: The patient health query.\n2. Candidate Description: The patient health question of 15 words or less.\n",
    "dialogue2note": "Summarize the patient/doctor dialogue into an assessment and plan.\n1. Reference Description: The original patient/doctor dialogue.\n2. Candidate Description: The assessment and plan section.\n",
    "medication2answer": "Answer the following medication-related patient health question.\n1. Reference Description: The medication-related patient health question.\n2. Candidate Description: The answer to the medication-related question.\n"
}
"""