(no commit message)

medval/generator.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import dspy


class MedVAL_Generator(dspy.Signature):
    """
    Generate a candidate, given the reference composed by an expert.
    """

    instruction: str = dspy.InputField()
    reference: str = dspy.InputField()
    candidate: str = dspy.OutputField(
        description="Only respond with the candidate, do not include any additional text or explanation."
    )
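
MedVAL_Generator is a plain DSPy signature: it only declares the input and output fields, and a module such as dspy.ChainOfThought supplies the prompting logic (as pipeline.py does below). A minimal standalone sketch, assuming an OpenAI-compatible model is available; the instruction and reference strings are invented placeholders:

import dspy

from medval.generator import MedVAL_Generator

# Assumption: mirrors the default model used in pipeline.py below.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

generate = dspy.ChainOfThought(MedVAL_Generator)
result = generate(
    instruction="Write a one-sentence impression for this report.",  # placeholder prompt
    reference="CHEST X-RAY: No focal consolidation, pleural effusion, or pneumothorax.",
)
print(result.candidate)
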
medval/pipeline.py (new file, 250 lines)
@@ -0,0 +1,250 @@
import dspy
import os
import json
import pandas as pd
import random
from dspy.datasets import DataLoader
from utils.prompts import (
    adversarial_attacks,
    adversarial_attack_base,
    error_categories,
    task_keys,
)
from medval.generator import MedVAL_Generator
from medval.validator import MedVAL_Validator, DetectTask
from dspy.clients.lm_local import LocalProvider
from datasets import load_dataset
from modaic import PrecompiledProgram, PrecompiledConfig
from typing import Literal


def scale_to_unit_interval(val, num_levels):
    return (val - 1) / (num_levels - 1)


class MedVALConfig(PrecompiledConfig):
    tasks: list[Literal[*task_keys]] = [
        "report2simplified",
        "impression2simplified",
        "report2impression",
        "bhc2spanish",
        "query2question",
        "dialogue2note",
        "medication2answer",
    ]
    model: str = "openai/gpt-4o-mini"
    api_base: str | None = None
    data: Literal["train", "test"] = "test"
    n_samples: int | None = None
    debug: bool = False
    method: Literal["zero-shot", "finetune"] = "zero-shot"
    threshold: float = 0.5
    input_csv: str | None = None
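
MedVALConfig subclasses modaic's PrecompiledConfig, which appears to follow pydantic conventions, so each field above can presumably be overridden by keyword at construction time. A minimal sketch with illustrative values; unspecified fields keep the defaults shown above:

from medval.pipeline import MedVALConfig

# Illustrative overrides only; the task subset and sample count are arbitrary.
config = MedVALConfig(
    tasks=["report2impression"],
    model="openai/gpt-4o-mini",
    n_samples=10,
)
print(config.method, config.threshold)  # expected: zero-shot 0.5
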
class MedVAL(PrecompiledProgram):
    config: MedVALConfig

    def __init__(self, config: MedVALConfig, api_key: str | None = None, **kwargs):
        super().__init__(config, **kwargs)

        self.tasks = config.tasks
        self.model_name = config.model
        self.api_base = config.api_base
        self.api_key = api_key
        self.data = config.data
        self.n_samples = config.n_samples
        self.debug = config.debug
        self.method = config.method
        self.threshold = config.threshold
        self.input_csv = config.input_csv
        self.student_model = None
        self.generator = dspy.ChainOfThought(MedVAL_Generator).deepcopy()
        self.validator = dspy.ChainOfThought(MedVAL_Validator).deepcopy()
        self.task_detector = dspy.ChainOfThought(DetectTask).deepcopy()
        self.prompts = self._load_prompts()
        self._configure_lm()
        self.dl = DataLoader()

    def _load_prompts(self):
        with open("utils/task_prompts.json", "r") as file:
            return json.load(file)

    def _configure_lm(self):
        if (self.data == "train") and (self.model_name.startswith("local")):
            dspy.settings.experimental = True
            lm = dspy.LM(
                model=f"openai/local:{'/'.join(self.model_name.split('/')[1:])}",
                provider=LocalProvider(),
            )
            lm.launch()
            dspy.configure(lm=lm)

        else:
            lm = dspy.LM(
                model=self.model_name, api_base=self.api_base, api_key=self.api_key
            )
            if self.student_model is not None:
                dspy.settings.experimental = True
                self.generator.set_lm(lm)
                if not self.student_model.startswith("local"):
                    self.validator.set_lm(
                        dspy.LM(
                            model=self.student_model,
                            api_base=self.api_base,
                            api_key=self.api_key,
                        )
                    )
                else:
                    self.student_model = "/".join(self.student_model.split("/")[1:])
                    self.validator.set_lm(
                        dspy.LM(
                            model=f"openai/local:{self.student_model}",
                            provider=LocalProvider(),
                        )
                    )
            else:
                dspy.configure(lm=lm)
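
One convention worth calling out in _configure_lm above: model names beginning with "local" are rewritten into the "openai/local:<hf-id>" form that DSPy's LocalProvider expects. A sketch of just that string handling, using a hypothetical Hugging Face model id:

model_name = "local/meta-llama/Llama-3.2-1B-Instruct"  # hypothetical id
hf_id = "/".join(model_name.split("/")[1:])            # "meta-llama/Llama-3.2-1B-Instruct"
dspy_model = f"openai/local:{hf_id}"                   # string handed to dspy.LM above
print(dspy_model)
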
    def load_data(self):
        if self.input_csv:
            df = pd.read_csv(self.input_csv)
        else:
            hf_dataset = load_dataset("stanfordmimi/MedVAL-Bench")
            dataset_split = (
                hf_dataset["train"] if self.data == "train" else hf_dataset["test"]
            )
            df = dataset_split.to_pandas()

        df = df.rename(
            columns={
                k: v
                for k, v in {
                    "input": "reference",
                    "reference_output": "target",
                    "output": "candidate",
                }.items()
                if k in df.columns and v not in df.columns
            }
        )
        df = df[df["task"].isin(self.tasks)]
        df = df.head(self.n_samples) if self.n_samples is not None else df
        print(f"\nTasks included: {', '.join(self.tasks)}")
        print(f"\nTotal # of samples: {len(df)}\n\n")
        df = (
            df.sample(frac=1, random_state=42).reset_index(drop=True)
            if self.data == "train"
            else df.reset_index(drop=True)
        )
        temp_csv_path = "temp.csv"
        df.to_csv(temp_csv_path, index=False)

        if self.data == "train":
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task"),
                input_keys=("reference", "target", "task"),
            )
            os.remove(temp_csv_path)
            return full_dataset, None
        else:
            full_dataset = self.dl.from_csv(
                temp_csv_path,
                fields=("reference", "target", "task", "candidate"),
                input_keys=("reference", "task", "candidate"),
            )
            os.remove(temp_csv_path)
            return df, full_dataset

    def generate(self, reference, attack_level, task):
        adversarial_instruction = (
            self.prompts[task]
            + adversarial_attack_base
            + adversarial_attacks[attack_level - 1]
            + "\n"
            + error_categories
        )
        result = self.generator(
            instruction=adversarial_instruction, reference=reference
        )
        return result["candidate"]

    def forward(self, reference, task, candidate=None, target=None):
        if candidate is None:
            random.seed(hash(reference) % (2**32))
            attack_level = random.randint(1, len(adversarial_attacks))
            candidate = self.generate(
                reference=reference, attack_level=attack_level, task=task
            )

        if not task:
            task_result = self.task_detector(candidate=candidate, reference=reference)
            task = task_result.task
        result = self.validator(
            instruction=self.prompts[task], reference=reference, candidate=candidate
        )

        if self.data == "train":
            candidate_clean = (
                self.generate(reference=reference, attack_level=1, task=task)
                if target is None
                else target
            )
            result_clean = self.validator(
                instruction=self.prompts[task],
                reference=reference,
                candidate=candidate_clean,
            )
            return dspy.Prediction(
                reason=result["reasoning"],
                err=result["errors"],
                attack_prediction=result["risk_level"],
                attack_level=attack_level,
                clean_prediction=result_clean["risk_level"],
            )

        return dspy.Prediction(
            reason=result["reasoning"],
            err=result["errors"],
            attack_prediction=result["risk_level"],
        )

    def validator_metric(self, example, pred, trace=None):
        delta = scale_to_unit_interval(
            pred["attack_level"], num_levels=len(adversarial_attacks)
        )
        pred_clean_score = scale_to_unit_interval(
            pred["clean_prediction"], num_levels=len(adversarial_attacks)
        )
        pred_adv_score = scale_to_unit_interval(
            pred["attack_prediction"], num_levels=len(adversarial_attacks)
        )

        absolute_consistency = (pred_clean_score**2) + (pred_adv_score - delta) ** 2
        relative_consistency = (pred_adv_score - pred_clean_score - delta) ** 2
        total_loss = absolute_consistency + relative_consistency
        metric_value = 1 - total_loss / 6

        if self.debug:
            print(dspy.inspect_history(n=4))
            exit()

        if (trace is not None) or (self.method == "finetune"):
            return metric_value >= self.threshold
        return metric_value
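
A quick numeric sanity check of validator_metric, assuming len(adversarial_attacks) == 4 so that scale_to_unit_interval maps levels 1 through 4 onto 0, 1/3, 2/3, 1. When the validator rates the clean candidate risk 1 and a maximally attacked candidate (attack_level 4) risk 4, every loss term vanishes and the metric is 1; with the ratings inverted, the two terms reach their maxima of 2 and 4, which is why the total is normalized by 6:

def scale_to_unit_interval(val, num_levels):
    return (val - 1) / (num_levels - 1)

num_levels = 4  # assumption: len(adversarial_attacks) == 4

delta = scale_to_unit_interval(4, num_levels)  # attack_level 4 -> 1.0

# Ideal validator: clean candidate rated 1, attacked candidate rated 4.
clean, adv = scale_to_unit_interval(1, num_levels), scale_to_unit_interval(4, num_levels)
loss = clean**2 + (adv - delta) ** 2 + (adv - clean - delta) ** 2
print(1 - loss / 6)  # 1.0

# Inverted validator: clean rated 4, attacked rated 1; the loss hits its ceiling of 6.
clean, adv = 1.0, 0.0
loss = clean**2 + (adv - delta) ** 2 + (adv - clean - delta) ** 2
print(1 - loss / 6)  # 0.0
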
    def save_results(self, df, method=None):
        df = df.where(pd.notnull(df), "None")
        df["lm_error_assessment"] = (
            df["lm_error_assessment"]
            .str.replace("\n\n", "\n", regex=False)
            .str.replace("\n \n", "\n", regex=False)
            .str.replace("\\n", "\n", regex=False)
        )

        results_path = f"results/{method}/"
        os.makedirs(results_path, exist_ok=True)

        if self.input_csv:
            csv_name = os.path.splitext(os.path.basename(self.input_csv))[0]
            file_path = f"{results_path}{self.model_name.split('/')[-1]}/{csv_name}.csv"
        else:
            file_path = (
                f"{results_path}{self.model_name.split('/')[-1]}/medval-bench.csv"
            )

        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        df.to_csv(file_path, index=False)
        print(f"\nResults saved to: {file_path}\n")
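
Putting the files together, a hedged sketch of how this pipeline appears intended to be driven for zero-shot evaluation. No driver script is part of this commit, so the call pattern, the assumption that PrecompiledProgram dispatches __call__ to forward the way a dspy.Module does, and the result columns written before save_results (which only requires lm_error_assessment to exist) are all inferred rather than confirmed:

import os

from medval.pipeline import MedVAL, MedVALConfig

config = MedVALConfig(tasks=["report2impression"], n_samples=5)
program = MedVAL(config, api_key=os.environ.get("OPENAI_API_KEY"))

df, dataset = program.load_data()  # test mode returns (DataFrame, DSPy examples)
preds = [
    program(reference=ex.reference, task=ex.task, candidate=ex.candidate)
    for ex in dataset
]

# Hypothetical column names, except lm_error_assessment, which save_results expects.
df["lm_risk_level"] = [p.attack_prediction for p in preds]
df["lm_error_assessment"] = [str(p.err) for p in preds]
program.save_results(df, method=config.method)
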
medval/validator.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import dspy
from utils.prompts import (
    errors_prompt,
    error_categories,
    risk_levels_prompt,
    task_keys,
    instruction_mappings_prompt,
)
from typing import Literal, List
from pydantic import BaseModel, Field


class ErrorAssessment(BaseModel):
    error_occurrence: str = Field(
        description="The exact snippet of text in the candidate where the error appears."
    )
    error: str = Field(
        description="A concise explanation of why the snippet is an error."
    )
    category: str = Field(
        description=f"One of the 11 predefined error categories:\n{error_categories}"
    )
    reasoning: str = Field(
        description="Detailed reasoning outlining why this portion of the candidate is factually inconsistent with the reference."
    )


class DetectTask(dspy.Signature):
    """
    Detect the intended task from the reference text and the generated candidate.
    """

    reference: str = dspy.InputField()
    candidate: str = dspy.InputField()
    task: Literal[*task_keys] = dspy.OutputField(
        description=instruction_mappings_prompt
    )


class MedVAL_Validator(dspy.Signature):
    """
    Evaluate a candidate in comparison to the reference composed by an expert.

    Instructions:
    1. Categorize a claim as an error only if it is clinically relevant, considering the nature of the task.
    2. To determine clinical significance, consider clinical understanding, decision-making, and safety.
    3. Some tasks (e.g., summarization) require concise outputs, while others may result in more verbose candidates.
        - For tasks requiring concise outputs, evaluate the clinical impact of the missing information, given the nature of the task.
        - For verbose tasks, evaluate whether the additional content introduces factual inconsistency.
    """

    instruction: str = dspy.InputField()
    reference: str = dspy.InputField()
    candidate: str = dspy.InputField()
    # errors: str = dspy.OutputField(description=errors_prompt)
    errors: List[ErrorAssessment] = dspy.OutputField(description=errors_prompt)
    risk_level: Literal[1, 2, 3, 4] = dspy.OutputField(description=risk_levels_prompt)
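
To make the validator's output contract concrete, a small illustrative ErrorAssessment built by hand; in practice these objects are produced by the LM through MedVAL_Validator, and the category string would be one of the labels defined in error_categories in utils/prompts.py (the one used here is invented):

from medval.validator import ErrorAssessment

example_error = ErrorAssessment(
    error_occurrence="small left pleural effusion",
    error="The reference describes a right-sided effusion, not a left-sided one.",
    category="Incorrect laterality",  # hypothetical label, for illustration only
    reasoning="Swapping laterality could misdirect follow-up imaging or intervention.",
)
print(example_error.model_dump_json(indent=2))  # assumes pydantic v2
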