MedVAL-pipeline/medval/pipeline.py
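
"""MedVAL pipeline: a precompiled DSPy program that pairs an adversarial
candidate generator with a chain-of-thought validator for medical
text-generation tasks.

The module defines MedVALConfig (task list, model, data split, sampling, and
thresholding options) and MedVAL, which loads MedVAL-Bench or a user-supplied
CSV, generates adversarially perturbed candidates, scores them with the
validator, computes a consistency-based training metric, and writes per-model
result CSVs.
"""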
import dspy
import os
import json
import pandas as pd
import random
from dspy.datasets import DataLoader
from utils.prompts import (
    adversarial_attacks,
    adversarial_attack_base,
    error_categories,
    task_keys,
)
from medval.generator import MedVAL_Generator
from medval.validator import MedVAL_Validator, DetectTask
from dspy.clients.lm_local import LocalProvider
from datasets import load_dataset
from modaic import PrecompiledProgram, PrecompiledConfig
from typing import Literal


def scale_to_unit_interval(val, num_levels):
    return (val - 1) / (num_levels - 1)
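
# Maps a 1-based level in {1, ..., num_levels} onto [0, 1]; e.g. with
# num_levels=4, levels 1, 2, 3, 4 map to 0.0, 1/3, 2/3, 1.0.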


class MedVALConfig(PrecompiledConfig):
    tasks: list[Literal[*task_keys]] = [
        "report2simplified",
        "impression2simplified",
        "report2impression",
        "bhc2spanish",
        "query2question",
        "dialogue2note",
        "medication2answer",
    ]
    model: str = "openai/gpt-4o-mini"
    api_base: str | None = None
    data: Literal["train", "test"] = "test"
    n_samples: int | None = None
    debug: bool = False
    method: Literal["zero-shot", "finetune"] = "zero-shot"
    threshold: float = 0.5
    input_csv: str | None = None
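
# Illustrative construction (values are examples, and keyword overrides assume
# PrecompiledConfig behaves like a pydantic-style model):
#     config = MedVALConfig(tasks=["report2impression"], data="test", n_samples=10)
#     pipeline = MedVAL(config, api_key=os.environ.get("OPENAI_API_KEY"))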


class MedVAL(PrecompiledProgram):
    config: MedVALConfig

    def __init__(self, config: MedVALConfig, api_key: str | None = None, **kwargs):
        super().__init__(config, **kwargs)
        self.tasks = config.tasks
        self.model_name = config.model
        self.api_base = config.api_base
        self.api_key = api_key
        self.data = config.data
        self.n_samples = config.n_samples
        self.debug = config.debug
        self.method = config.method
        self.threshold = config.threshold
        self.input_csv = config.input_csv
        self.student_model = None
        self.generator = dspy.ChainOfThought(MedVAL_Generator).deepcopy()
        self.validator = dspy.ChainOfThought(MedVAL_Validator).deepcopy()
        self.task_detector = dspy.ChainOfThought(DetectTask).deepcopy()
        self.prompts = self._load_prompts()
        self._configure_lm()
        self.dl = DataLoader()

    def _load_prompts(self):
        with open("utils/task_prompts.json", "r") as file:
            return json.load(file)

    def _configure_lm(self):
        if (self.data == "train") and (self.model_name.startswith("local")):
            dspy.settings.experimental = True
            lm = dspy.LM(
                model=f"openai/local:{'/'.join(self.model_name.split('/')[1:])}",
                provider=LocalProvider(),
            )
            lm.launch()
            dspy.configure(lm=lm)
        else:
            lm = dspy.LM(
                model=self.model_name, api_base=self.api_base, api_key=self.api_key
            )
            if self.student_model is not None:
                dspy.settings.experimental = True
                self.generator.set_lm(lm)
                if not self.student_model.startswith("local"):
                    self.validator.set_lm(
                        dspy.LM(
                            model=self.student_model,
                            api_base=self.api_base,
                            api_key=self.api_key,
                        )
                    )
                else:
                    self.student_model = "/".join(self.student_model.split("/")[1:])
                    self.validator.set_lm(
                        dspy.LM(
                            model=f"openai/local:{self.student_model}",
                            provider=LocalProvider(),
                        )
                    )
            else:
                dspy.configure(lm=lm)

    def load_data(self):
        if self.input_csv:
            df = pd.read_csv(self.input_csv)
        else:
            hf_dataset = load_dataset("stanfordmimi/MedVAL-Bench")
            dataset_split = (
                hf_dataset["train"] if self.data == "train" else hf_dataset["test"]
            )
            df = dataset_split.to_pandas()
        df = df.rename(
            columns={
                k: v
                for k, v in {
                    "input": "reference",
                    "reference_output": "target",
                    "output": "candidate",
                }.items()
                if k in df.columns and v not in df.columns
            }
        )
        df = df[df["task"].isin(self.tasks)]
        df = df.head(self.n_samples) if self.n_samples is not None else df
        print(f"\nTasks included: {', '.join(self.tasks)}")
        print(f"\nTotal # of samples: {len(df)}\n\n")
        df = (
            df.sample(frac=1, random_state=42).reset_index(drop=True)
            if self.data == "train"
            else df.reset_index(drop=True)
        )
        temp_csv_path = "temp.csv"
        df.to_csv(temp_csv_path, index=False)
        if self.data == "train":
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task"),
                input_keys=("reference", "target", "task"),
            )
            os.remove(temp_csv_path)
            return full_dataset, None
        else:
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task", "candidate"),
                input_keys=("reference", "task", "candidate"),
            )
            os.remove(temp_csv_path)
            return df, full_dataset
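
    # Return convention: the "train" split yields (dspy dataset, None), while
    # "test" (or a user-supplied CSV) yields (pandas DataFrame, dspy dataset),
    # keeping the full rows available for later saving.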

    def generate(self, reference, attack_level, task):
        adversarial_instruction = (
            self.prompts[task]
            + adversarial_attack_base
            + adversarial_attacks[attack_level - 1]
            + "\n"
            + error_categories
        )
        result = self.generator(
            instruction=adversarial_instruction, reference=reference
        )
        return result["candidate"]
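
    # attack_level indexes adversarial_attacks (1-based); forward() also reuses
    # level 1 to produce the "clean" candidate when training without a target.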

    def forward(self, reference, task, candidate=None, target=None):
        if candidate is None:
            random.seed(hash(reference) % (2**32))
            attack_level = random.randint(1, len(adversarial_attacks))
            candidate = self.generate(
                reference=reference, attack_level=attack_level, task=task
            )
        if not task:
            task_result = self.task_detector(candidate=candidate, reference=reference)
            task = task_result.task
        result = self.validator(
            instruction=self.prompts[task], reference=reference, candidate=candidate
        )
        if self.data == "train":
            candidate_clean = (
                self.generate(reference=reference, attack_level=1, task=task)
                if target is None
                else target
            )
            result_clean = self.validator(
                instruction=self.prompts[task],
                reference=reference,
                candidate=candidate_clean,
            )
            return dspy.Prediction(
                reason=result["reasoning"],
                err=result["errors"],
                attack_prediction=result["risk_level"],
                attack_level=attack_level,
                clean_prediction=result_clean["risk_level"],
            )
        return dspy.Prediction(
            reason=result["reasoning"],
            err=result["errors"],
            attack_prediction=result["risk_level"],
        )

    def validator_metric(self, example, pred, trace=None):
        delta = scale_to_unit_interval(
            pred["attack_level"], num_levels=len(adversarial_attacks)
        )
        pred_clean_score = scale_to_unit_interval(
            pred["clean_prediction"], num_levels=len(adversarial_attacks)
        )
        pred_adv_score = scale_to_unit_interval(
            pred["attack_prediction"], num_levels=len(adversarial_attacks)
        )
        absolute_consistency = (pred_clean_score**2) + (pred_adv_score - delta) ** 2
        relative_consistency = (pred_adv_score - pred_clean_score - delta) ** 2
        total_loss = absolute_consistency + relative_consistency
        metric_value = 1 - total_loss / 6
        if self.debug:
            print(dspy.inspect_history(n=4))
            exit()
        if (trace is not None) or (self.method == "finetune"):
            return metric_value >= self.threshold
        return metric_value
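
    # Why the division by 6: pred_clean_score, pred_adv_score, and delta each lie
    # in [0, 1], so absolute_consistency <= 2 and relative_consistency <= 4 (the
    # inner difference lies in [-2, 1]); hence total_loss <= 6 and metric_value
    # stays within [0, 1].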

    def save_results(self, df, method=None):
        df = df.where(pd.notnull(df), "None")
        df["lm_error_assessment"] = (
            df["lm_error_assessment"]
            .str.replace("\n\n", "\n", regex=False)
            .str.replace("\n \n", "\n", regex=False)
            .str.replace("\\n", "\n", regex=False)
        )
        results_path = f"results/{method}/"
        os.makedirs(results_path, exist_ok=True)
        if self.input_csv:
            csv_name = os.path.splitext(os.path.basename(self.input_csv))[0]
            file_path = f"{results_path}{self.model_name.split('/')[-1]}/{csv_name}.csv"
        else:
            file_path = (
                f"{results_path}{self.model_name.split('/')[-1]}/medval-bench.csv"
            )
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        df.to_csv(file_path, index=False)
        print(f"\nResults saved to: {file_path}\n")
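

# Minimal usage sketch (assumed, not part of the original entry points; it relies
# on MedVAL being callable like a dspy.Module through PrecompiledProgram):
#
#     pipeline = MedVAL(MedVALConfig(data="test", n_samples=5))
#     df, dataset = pipeline.load_data()
#     preds = [
#         pipeline(reference=ex.reference, task=ex.task, candidate=ex.candidate)
#         for ex in dataset
#     ]
#     print(preds[0].attack_prediction)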