58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
import dspy
|
|
from utils.prompts import (
|
|
errors_prompt,
|
|
error_categories,
|
|
risk_levels_prompt,
|
|
task_keys,
|
|
instruction_mappings_prompt,
|
|
)
|
|
from typing import Literal, List
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class ErrorAssessment(BaseModel):
    # Structured record describing one error found in a candidate text.
    # It is the element type of `MedVAL_Validator.errors`.
    #
    # NOTE: documented with comments instead of a class docstring on purpose:
    # pydantic copies a model docstring into the generated JSON schema, which
    # would alter the schema text derived from this model.

    # Verbatim span of the candidate in which the error occurs.
    error_occurrence: str = Field(
        description=(
            "The exact snippet of text in the candidate "
            "where the error appears."
        ),
    )

    # Short justification for flagging the span above.
    error: str = Field(
        description="A concise explanation of why the snippet is an error.",
    )

    # Category label; the list of allowed categories is interpolated into the
    # description from `error_categories` (imported from utils.prompts).
    category: str = Field(
        description=f"One of the 11 predefined error categories:\n{error_categories}",
    )

    # Longer argument tying the error back to the reference text.
    reasoning: str = Field(
        description=(
            "Detailed reasoning outlining why this portion of the candidate "
            "is factually inconsistent with the reference."
        ),
    )
|
|
|
|
|
|
class DetectTask(dspy.Signature):
    """
    Detect the intended task from the reference text and the generated candidate
    """

    # NOTE: DSPy reads the docstring above as the instruction text for this
    # signature, so it is part of runtime behavior; keep maintainer notes in
    # `#` comments like this one.

    # Inputs: the reference text and the generated candidate to classify.
    reference: str = dspy.InputField()
    candidate: str = dspy.InputField()

    # Output: exactly one of the keys in `task_keys`.
    # `Literal[tuple(task_keys)]` builds the same type as `Literal[*task_keys]`
    # (both subscript Literal with a tuple of the keys), but star-unpacking
    # inside a subscript is a SyntaxError before Python 3.11, so the tuple
    # form keeps this module importable on older interpreters.
    task: Literal[tuple(task_keys)] = dspy.OutputField(
        description=instruction_mappings_prompt
    )
|
|
|
|
|
|
class MedVAL_Validator(dspy.Signature):
    """
    Evaluate a candidate in comparison to the reference composed by an expert.

    Instructions:
    1. Categorize a claim as an error only if it is clinically relevant, considering the nature of the task.
    2. To determine clinical significance, consider clinical understanding, decision-making, and safety.
    3. Some tasks (e.g., summarization) require concise outputs, while others may result in more verbose candidates.
    - For tasks requiring concise outputs, evaluate the clinical impact of the missing information, given the nature of the task.
    - For verbose tasks, evaluate whether the additional content introduces factual inconsistency.
    """

    # NOTE: DSPy reads the docstring above as the instruction text for this
    # signature, so its wording is part of runtime behavior; maintainer notes
    # belong in `#` comments.

    # ---- inputs ----
    # presumably the task instruction chosen for the detected task
    # (see DetectTask) -- TODO confirm against the caller
    instruction: str = dspy.InputField()
    # expert-composed reference text
    reference: str = dspy.InputField()
    # generated text being evaluated against the reference
    candidate: str = dspy.InputField()

    # ---- outputs ----
    # Structured list of detected errors, one ErrorAssessment per finding.
    # (An earlier variant declared this field as a free-form `str`.)
    errors: List[ErrorAssessment] = dspy.OutputField(
        description=errors_prompt,
    )
    # Risk grade restricted to the integers 1-4; what each level means is
    # defined by `risk_levels_prompt`.
    risk_level: Literal[1, 2, 3, 4] = dspy.OutputField(
        description=risk_levels_prompt,
    )
|