# taken from https://github.com/seacowx/OpenToM/blob/main/src/evaluate/opentom_evaluator.py
# modified for usability

from collections import defaultdict
import json


class OpenToMEvaluatorDspy:
    def __init__(self, model_name="") -> None:
        # per-class tallies keyed by "{question_type}_{label}", used for F1
        self.true_positives = defaultdict(lambda: 0)
        self.false_positives = defaultdict(lambda: 0)
        self.false_negatives = defaultdict(lambda: 0)
        self.model_name = model_name

    def dspy_metric(self, example, pred_answer, trace=None):
        question_type = example.type

        eval_result = self.check_answer(example, pred_answer.answer)
        if eval_result is None:
            # the example is invalid for the current perspective; there is no
            # obviously correct dspy metric value here, so skip it entirely
            return None
        gt, pred = eval_result  # ground-truth answer class, predicted answer class

        # store positive/negative results by class so we can calculate the F1 scores later
        if gt == pred:
            self.true_positives[f"{question_type}_{pred}"] += 1
        else:
            self.false_positives[f"{question_type}_{pred}"] += 1
            self.false_negatives[f"{question_type}_{gt}"] += 1

        return gt == pred
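
    # Sketch of plugging this metric into a dspy evaluation loop (illustrative;
    # the Evaluate signature may differ across dspy versions, so treat this as
    # an assumption to verify against your install):
    #
    #   evaluator = OpenToMEvaluatorDspy(model_name="my-model")
    #   evaluate = dspy.evaluate.Evaluate(devset=devset, metric=evaluator.dspy_metric)
    #   evaluate(my_program)
    #   evaluator.print_f1_results()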

    # this method was added to make dspy evaluation easier
    def check_answer(
        self,
        example,
        pred_answer,
        cot_flag=False,
        perspective="all",
    ):
        # plot_info is a JSON object whose values are, in order: the character
        # who moves the entity, the character affected by the move, the entity
        # of interest, and the original/destination places
        mover, affected_char, eoi, original_place, move_to_place = json.loads(
            example.plot_info
        ).values()

        cur_question_type = example.type
        question_content = example.question

        gt_answer = example.answer.strip()
        pred_answer = pred_answer.strip()

        # NOTE: evaluate based on the character the question asks about; when
        # both characters are mentioned, the one appearing first is treated as
        # the subject of the question
        if perspective == "observer":
            if mover in question_content and affected_char not in question_content:
                return None

            if mover in question_content and affected_char in question_content:
                question_tokens = (
                    question_content.replace("'s", "").replace(",", "").split()
                )

                mover_idx = question_tokens.index(mover)
                affected_char_idx = question_tokens.index(affected_char)

                if mover_idx < affected_char_idx:
                    return None

        elif perspective == "mover":
            if mover not in question_content and affected_char in question_content:
                return None

            if mover in question_content and affected_char in question_content:
                question_tokens = (
                    question_content.replace("'s", "").replace(",", "").split()
                )

                mover_idx = question_tokens.index(mover)
                affected_char_idx = question_tokens.index(affected_char)

                if mover_idx > affected_char_idx:
                    return None

        if cot_flag:
            pred_answer = self.parse_cot_answer(pred_answer)

        # the first-order (fo) and second-order (so) variants of each location
        # question share the same scoring logic, so they are dispatched together
        if cur_question_type in ("location-fo-coarse", "location-so-coarse"):
            return self.check_answer_for_cg_location(pred_answer, gt_answer)

        elif cur_question_type in ("location-fo-fine", "location-so-fine"):
            return self.check_answer_for_fg_location(
                pred_answer, gt_answer, original_place, move_to_place
            )

        elif cur_question_type in ("multihop-fo", "multihop-so"):
            if "fullness" in question_content:
                return self.check_fullness_answer(pred_answer, gt_answer)

            elif "accessibility" in question_content:
                # a "|"-joined ground truth means both answers are acceptable,
                # which the dataset treats as "equally accessible"
                if "|" in gt_answer:
                    gt_answer = "equally accessible"

                # kept from the upstream evaluator, which also accepted
                # list-typed answers; unreachable here after .strip() above
                if isinstance(gt_answer, list):
                    gt_answer = [ele for ele in gt_answer if ele != "corrupted"]
                    assert len(gt_answer) == 1
                    gt_answer = gt_answer[0]

                return self.check_accessibility_answer(pred_answer, gt_answer)

        elif cur_question_type == "attitude":
            return self.check_attitude_answer(pred_answer, gt_answer)

        # unknown question type (or a multihop question that mentions neither
        # fullness nor accessibility): treat the example as invalid
        return None
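
    # Illustrative call (hypothetical example object; the field layout matches
    # what check_answer reads, but the values are made up):
    #
    #   from types import SimpleNamespace
    #   ex = SimpleNamespace(
    #       type="location-fo-coarse",
    #       question="Does Anne think the rubber duck is in its initial location?",
    #       answer="yes",
    #       plot_info='{"mover": "Sally", "affected_char": "Anne", "eoi": "duck",'
    #                 ' "original_place": "basket", "move_to_place": "box"}',
    #   )
    #   evaluator.check_answer(ex, "Yes")  # -> [1, 1]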

    def f1_score(self):
        true_positives = self.true_positives
        false_positives = self.false_positives
        false_negatives = self.false_negatives
        f1_scores = defaultdict(lambda: {"by_class": {}})

        for _class in (
            true_positives.keys() | false_positives.keys() | false_negatives.keys()
        ):
            # keys look like "attitude_1"; rsplit guards against any future
            # question type that itself contains an underscore
            question_type, _ = _class.rsplit("_", 1)
            class_true_positives = true_positives[_class]
            class_false_positives = false_positives[_class]
            class_false_negatives = false_negatives[_class]
            # when there are no true positives, precision and recall are 0 by
            # convention, which also avoids dividing by zero
            class_precision = (
                class_true_positives / (class_true_positives + class_false_positives)
                if class_true_positives > 0.0
                else 0.0
            )
            class_recall = (
                class_true_positives / (class_true_positives + class_false_negatives)
                if class_true_positives > 0.0
                else 0.0
            )
            class_f1_score = (
                (2 * class_precision * class_recall) / (class_precision + class_recall)
                if class_precision > 0.0 or class_recall > 0.0
                else 0.0
            )
            f1_scores[question_type]["by_class"][_class] = class_f1_score

        # macro-average the per-class scores within each question type
        for question_type, type_f1_scores in f1_scores.items():
            type_f1_scores = type_f1_scores["by_class"]
            macro_averaged_f1_score = sum(type_f1_scores.values()) / len(
                type_f1_scores
            )
            f1_scores[question_type]["macro_averaged"] = macro_averaged_f1_score

        return f1_scores
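
    # Worked example (hypothetical tallies): with true_positives["attitude_1"] = 2,
    # false_positives["attitude_1"] = 1 and false_negatives["attitude_1"] = 1,
    # precision = recall = 2/3, so the class F1 is 2/3 ~= 0.667; with no other
    # "attitude" classes observed, the macro-averaged score is also ~0.667.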

    # pretty print macro averaged F1 scores for each question type
    def print_f1_results(self, round_decimal=2, print_header=False):
        f1_scores = self.f1_score()
        if print_header:
            print("Macro Averaged F1 Scores by question type")

        print(self.model_name, end=" - ")
        for question_type, type_f1_scores in f1_scores.items():
            # convert to a percentage before rounding so the printed value
            # really has round_decimal decimal places
            print(
                f"{question_type}: {round(type_f1_scores['macro_averaged'] * 100, round_decimal)}",
                end="\t",
            )
        print()

    @staticmethod
    def remove_determinant(word: str) -> str:
        determinants = ["a", "an", "the"]
        for det in determinants:
            # match whole determiners only, so e.g. "attic" is not stripped to "ttic"
            if word.startswith(det + " "):
                return word[len(det) :].strip()
        return word
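
    # e.g. remove_determinant("the attic") -> "attic",
    #      remove_determinant("basket")    -> "basket"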

    @staticmethod
    def compute_lexical_overlap(pred: str, location: str) -> float:
        # fraction of the location's words that also appear in the prediction
        pred = pred.lower().replace("_", " ").replace("'s", "")
        location = location.lower().replace("_", " ").replace("'s", "")
        score = 0
        pred = pred.replace(".", "").split()
        location = location.split()
        visited_word = []

        for word in pred:
            if word in location and word not in visited_word:
                score += 1
                visited_word.append(word)

        return score / len(location)
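
    # e.g. compute_lexical_overlap("inside the green basket", "green basket") -> 1.0
    #      compute_lexical_overlap("on the table", "green basket")            -> 0.0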

    def parse_cot_answer(self, answer: str) -> str:
        # CoT responses typically state the final answer in the last line or
        # after a concluding "Therefore"
        if "\n" in answer:
            answer = answer.split("\n")[-1]
        else:
            answer = answer.split("Therefore")[-1]
        return answer
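
    # e.g. parse_cot_answer("She saw the move.\nTherefore, yes.") -> "Therefore, yes."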

    def check_answer_for_fg_location(
        self, prediction: str, answer: str, original_place: str, move_to_place: str
    ) -> list:
        # predictions often contain explanations, so instead of exact matching,
        # score the prediction's lexical overlap with both candidate places
        # labels: 1 = original place, 2 = destination, 3 = tie between the two
        answer = self.remove_determinant(answer).lower()
        original_place = self.remove_determinant(original_place).lower()
        move_to_place = self.remove_determinant(move_to_place).lower()
        gt_label, pred_label = None, None
        original_place_score = self.compute_lexical_overlap(prediction, original_place)
        move_to_place_score = self.compute_lexical_overlap(prediction, move_to_place)

        if original_place_score == move_to_place_score:
            pred_label = 3
        elif original_place_score > move_to_place_score:
            pred_label = 1
        else:
            pred_label = 2

        if original_place == answer:
            gt_label = 1
        elif move_to_place == answer:
            gt_label = 2

        return [gt_label, pred_label]
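
    # e.g. with original_place "basket" and move_to_place "box":
    #   check_answer_for_fg_location("inside the box", "box", "basket", "box") -> [2, 2]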

    def check_answer_for_cg_location(self, prediction: str, answer: str) -> list:
        prediction = prediction.lower()
        answer = answer.lower()
        gt_label, pred_label = None, None

        # note: these are substring checks, so a word like "knows" also
        # matches "no"; predictions mentioning both (or neither) polarity are
        # marked corrupted (-1)
        if "no" in prediction and "yes" not in prediction:
            pred_label = 0
        elif "yes" in prediction and "no" not in prediction:
            pred_label = 1
        else:
            pred_label = -1

        if "no" in answer:
            gt_label = 0
        elif "yes" in answer:
            gt_label = 1

        return [gt_label, pred_label]
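
    # e.g. check_answer_for_cg_location("No, she does not.", "no") -> [0, 0]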

    def check_fullness_answer(self, prediction: str, answer: str) -> list:
        # labels: 1 = less full, 2 = more full, 3 = equally full, -1 = corrupted
        prediction = prediction.replace(".", "").lower()
        less_full_answer_list = ["less full", "emptier", "more empty"]
        more_full_answer_list = ["more full", "fuller"]
        pred_label, gt_label = None, None
        for less_full_ans in less_full_answer_list:
            if less_full_ans in prediction:
                pred_label = 1

        if not pred_label:
            for more_full_ans in more_full_answer_list:
                if more_full_ans in prediction:
                    pred_label = 2

        if not pred_label:
            if "equally full" in prediction:
                pred_label = 3

        if not pred_label:
            pred_label = -1  # corrupted

        if answer == "less full":
            gt_label = 1
        elif answer == "more full":
            gt_label = 2
        elif answer == "equally full":
            gt_label = 3

        return [gt_label, pred_label]
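
    # e.g. check_fullness_answer("The box is now emptier.", "less full") -> [1, 1]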

    def check_accessibility_answer(self, prediction: str, answer: str) -> list:
        # labels: 1 = more accessible, 2 = less accessible, 3 = equally accessible
        prediction = prediction.replace(".", "").lower()
        pred_label, gt_label = None, None
        if "more accessible" in prediction:
            pred_label = 1
        elif "less accessible" in prediction:
            pred_label = 2
        elif "equally accessible" in prediction:
            pred_label = 3
        else:
            pred_label = -1  # corrupted

        if answer == "more accessible":
            gt_label = 1
        elif answer == "less accessible":
            gt_label = 2
        else:
            gt_label = 3

        return [gt_label, pred_label]
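
    # e.g. check_accessibility_answer("It is less accessible now.", "less accessible") -> [2, 2]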

    def check_attitude_answer(self, prediction: str, answer: str) -> list:
        # labels: 1 = positive, 2 = negative, 3 = neutral, -1 = corrupted
        prediction = prediction.lower()
        answer = answer.lower()
        answer_map = {"a": "positive", "b": "neutral", "c": "negative"}
        # take the last sentence of the final paragraph, which is where the
        # chosen option usually appears
        prediction_token = (
            prediction.split("\n\n")[-1].split(":")[-1].split(".")[0].strip().lower()
        )
        gt_label, pred_label = None, None

        if answer == "positive":
            gt_label = 1
        elif answer == "negative":
            gt_label = 2
        else:
            gt_label = 3

        try:
            # the prediction may be a bare option letter (a/b/c)
            prediction = answer_map[prediction_token]
            if prediction == "positive":
                pred_label = 1
            elif prediction == "negative":
                pred_label = 2
            else:
                pred_label = 3

        except KeyError:
            # otherwise look for exactly one sentiment word; predictions that
            # mention more than one (or none) are marked corrupted
            mentioned = [
                word
                for word in ("positive", "negative", "neutral")
                if word in prediction_token
            ]
            if mentioned == ["positive"]:
                pred_label = 1
            elif mentioned == ["negative"]:
                pred_label = 2
            elif mentioned == ["neutral"]:
                pred_label = 3
            else:
                pred_label = -1

        return [gt_label, pred_label]
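

if __name__ == "__main__":
    # Minimal smoke test with a hand-built example (illustrative only: the
    # field values and plot_info key names are made up, but the field layout
    # matches what check_answer reads; real usage would pass dspy examples
    # loaded from the OpenToM dataset)
    from types import SimpleNamespace

    example = SimpleNamespace(
        type="attitude",
        question="What is Anne's attitude towards the rubber duck?",
        answer="positive",
        plot_info=json.dumps(
            {
                "mover": "Sally",
                "affected_char": "Anne",
                "eoi": "rubber duck",
                "original_place": "basket",
                "move_to_place": "box",
            }
        ),
    )
    prediction = SimpleNamespace(answer="positive")

    evaluator = OpenToMEvaluatorDspy(model_name="demo")
    print(evaluator.dspy_metric(example, prediction))  # True
    evaluator.print_f1_results(print_header=True)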