import dspy
import os
import json
import pandas as pd
import random
from dspy.datasets import DataLoader
from utils.prompts import (
    adversarial_attacks,
    adversarial_attack_base,
    error_categories,
    task_keys,
)
from medval.generator import MedVAL_Generator
from medval.validator import MedVAL_Validator, DetectTask
from dspy.clients.lm_local import LocalProvider
from datasets import load_dataset
from modaic import PrecompiledProgram, PrecompiledConfig
from typing import Literal


def scale_to_unit_interval(val, num_levels):
    # Map an integer level in [1, num_levels] onto [0, 1].
    return (val - 1) / (num_levels - 1)


class MedVALConfig(PrecompiledConfig):
    tasks: list[Literal[*task_keys]] = [
        "report2simplified",
        "impression2simplified",
        "report2impression",
        "bhc2spanish",
        "query2question",
        "dialogue2note",
        "medication2answer",
    ]
    model: str = "openai/gpt-4o-mini"
    api_base: str | None = None
    data: Literal["train", "test"] = "test"
    n_samples: int | None = None
    debug: bool = False
    method: Literal["zero-shot", "finetune"] = "zero-shot"
    threshold: float = 0.5
    input_csv: str | None = None


class MedVAL(PrecompiledProgram):
    config: MedVALConfig

    def __init__(self, config: MedVALConfig, api_key: str | None = None, **kwargs):
        super().__init__(config, **kwargs)
        self.tasks = config.tasks
        self.model_name = config.model
        self.api_base = config.api_base
        self.api_key = api_key
        self.data = config.data
        self.n_samples = config.n_samples
        self.debug = config.debug
        self.method = config.method
        self.threshold = config.threshold
        self.input_csv = config.input_csv
        self.student_model = None

        # Generator synthesizes candidate texts, validator scores them, and the
        # task detector infers the task when it is not supplied.
        self.generator = dspy.ChainOfThought(MedVAL_Generator).deepcopy()
        self.validator = dspy.ChainOfThought(MedVAL_Validator).deepcopy()
        self.task_detector = dspy.ChainOfThought(DetectTask).deepcopy()

        self.prompts = self._load_prompts()
        self._configure_lm()
        self.dl = DataLoader()

    def _load_prompts(self):
        with open("utils/task_prompts.json", "r") as file:
            return json.load(file)

    def _configure_lm(self):
        # Training with a "local/..." model launches an in-process LM server;
        # otherwise a hosted (OpenAI-compatible) endpoint is configured.
        if (self.data == "train") and (self.model_name.startswith("local")):
            dspy.settings.experimental = True
            lm = dspy.LM(
                model=f"openai/local:{'/'.join(self.model_name.split('/')[1:])}",
                provider=LocalProvider(),
            )
            lm.launch()
            dspy.configure(lm=lm)
        else:
            lm = dspy.LM(
                model=self.model_name, api_base=self.api_base, api_key=self.api_key
            )
            if self.student_model is not None:
                # Student/teacher split: the generator keeps the teacher LM while
                # the validator runs on the student model.
                dspy.settings.experimental = True
                self.generator.set_lm(lm)
                if not self.student_model.startswith("local"):
                    self.validator.set_lm(
                        dspy.LM(
                            model=self.student_model,
                            api_base=self.api_base,
                            api_key=self.api_key,
                        )
                    )
                else:
                    self.student_model = "/".join(self.student_model.split("/")[1:])
                    self.validator.set_lm(
                        dspy.LM(
                            model=f"openai/local:{self.student_model}",
                            provider=LocalProvider(),
                        )
                    )
            else:
                dspy.configure(lm=lm)

    def load_data(self):
        if self.input_csv:
            df = pd.read_csv(self.input_csv)
        else:
            # Pull MedVAL-Bench from the Hugging Face Hub and normalize its
            # column names to (reference, target, candidate).
            hf_dataset = load_dataset("stanfordmimi/MedVAL-Bench")
            dataset_split = (
                hf_dataset["train"] if self.data == "train" else hf_dataset["test"]
            )
            df = dataset_split.to_pandas()
            df = df.rename(
                columns={
                    k: v
                    for k, v in {
                        "input": "reference",
                        "reference_output": "target",
                        "output": "candidate",
                    }.items()
                    if k in df.columns and v not in df.columns
                }
            )

        df = df[df["task"].isin(self.tasks)]
        df = df.head(self.n_samples) if self.n_samples is not None else df

        print(f"\nTasks included: {', '.join(self.tasks)}")
        print(f"\nTotal # of samples: {len(df)}\n\n")

        # Shuffle only for training; keep the original order for evaluation.
        df = (
            df.sample(frac=1, random_state=42).reset_index(drop=True)
            if self.data == "train"
            else df.reset_index(drop=True)
        )

        # Round-trip through a temporary CSV so the dspy DataLoader can build
        # Example objects with the right input keys.
        temp_csv_path = "temp.csv"
        df.to_csv(temp_csv_path, index=False)

        if self.data == "train":
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task"),
                input_keys=("reference", "target", "task"),
            )
            os.remove(temp_csv_path)
            return full_dataset, None
        else:
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task", "candidate"),
                input_keys=("reference", "task", "candidate"),
            )
            os.remove(temp_csv_path)
            return df, full_dataset

    def generate(self, reference, attack_level, task):
        # Compose the task prompt with the adversarial attack at the requested
        # severity level (1 = cleanest, len(adversarial_attacks) = strongest).
        adversarial_instruction = (
            self.prompts[task]
            + adversarial_attack_base
            + adversarial_attacks[attack_level - 1]
            + "\n"
            + error_categories
        )
        result = self.generator(
            instruction=adversarial_instruction, reference=reference
        )
        return result["candidate"]

    def forward(self, reference, task, candidate=None, target=None):
        # If no candidate is supplied, synthesize one at a deterministic,
        # reference-seeded attack level.
        if candidate is None:
            random.seed(hash(reference) % (2**32))
            attack_level = random.randint(1, len(adversarial_attacks))
            candidate = self.generate(
                reference=reference, attack_level=attack_level, task=task
            )

        if not task:
            task_result = self.task_detector(candidate=candidate, reference=reference)
            task = task_result.task

        result = self.validator(
            instruction=self.prompts[task], reference=reference, candidate=candidate
        )

        if self.data == "train":
            # Also score a clean candidate (the provided target, or a level-1
            # generation) so the metric can compare clean vs. attacked scores.
            # Training mode assumes candidate was synthesized above, so
            # attack_level is defined here.
            candidate_clean = (
                self.generate(reference=reference, attack_level=1, task=task)
                if target is None
                else target
            )
            result_clean = self.validator(
                instruction=self.prompts[task],
                reference=reference,
                candidate=candidate_clean,
            )
            return dspy.Prediction(
                reason=result["reasoning"],
                err=result["errors"],
                attack_prediction=result["risk_level"],
                attack_level=attack_level,
                clean_prediction=result_clean["risk_level"],
            )

        return dspy.Prediction(
            reason=result["reasoning"],
            err=result["errors"],
            attack_prediction=result["risk_level"],
        )

    def validator_metric(self, example, pred, trace=None):
        delta = scale_to_unit_interval(
            pred["attack_level"], num_levels=len(adversarial_attacks)
        )
        pred_clean_score = scale_to_unit_interval(
            pred["clean_prediction"], num_levels=len(adversarial_attacks)
        )
        pred_adv_score = scale_to_unit_interval(
            pred["attack_prediction"], num_levels=len(adversarial_attacks)
        )

        # Absolute consistency penalizes any risk assigned to the clean
        # candidate and deviation of the attacked score from the attack level;
        # relative consistency penalizes the gap between the two scores
        # deviating from the attack level.
        absolute_consistency = (pred_clean_score**2) + (pred_adv_score - delta) ** 2
        relative_consistency = (pred_adv_score - pred_clean_score - delta) ** 2
        total_loss = absolute_consistency + relative_consistency
        # total_loss is bounded by 6, so metric_value lies in [0, 1].
        metric_value = 1 - total_loss / 6

        if self.debug:
            print(dspy.inspect_history(n=4))
            exit()

        # During compilation (trace) or finetuning, return a pass/fail signal;
        # otherwise return the continuous metric.
        if (trace is not None) or (self.method == "finetune"):
            return metric_value >= self.threshold
        return metric_value

    def save_results(self, df, method=None):
        df = df.where(pd.notnull(df), "None")
        # Normalize newline artifacts in the model's error assessments.
        df["lm_error_assessment"] = (
            df["lm_error_assessment"]
            .str.replace("\n\n", "\n", regex=False)
            .str.replace("\n \n", "\n", regex=False)
            .str.replace("\\n", "\n", regex=False)
        )

        results_path = f"results/{method}/"
        os.makedirs(results_path, exist_ok=True)

        if self.input_csv:
            csv_name = os.path.splitext(os.path.basename(self.input_csv))[0]
            file_path = f"{results_path}{self.model_name.split('/')[-1]}/{csv_name}.csv"
        else:
            file_path = (
                f"{results_path}{self.model_name.split('/')[-1]}/medval-bench.csv"
            )

        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        df.to_csv(file_path, index=False)
        print(f"\nResults saved to: {file_path}\n")
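

# ---------------------------------------------------------------------------
# Example usage: a minimal, illustrative sketch, not part of the original
# module. It assumes MedVALConfig accepts keyword overrides, that an
# OpenAI-compatible key is exported as OPENAI_API_KEY, and that the utils/
# prompt files and the stanfordmimi/MedVAL-Bench dataset are reachable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = MedVALConfig(tasks=["report2impression"], data="test", n_samples=1)
    program = MedVAL(config, api_key=os.environ.get("OPENAI_API_KEY"))

    # In test mode, load_data() returns (DataFrame, list of dspy Examples).
    df, examples = program.load_data()
    example = examples[0]

    # Validate a single (reference, candidate) pair for the detected task.
    prediction = program.forward(
        reference=example.reference,
        task=example.task,
        candidate=example.candidate,
    )
    print(prediction.attack_prediction)
    print(prediction.err)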