# taken from https://github.com/seacowx/OpenToM/blob/main/src/evaluate/opentom_evaluator.py
# modified for usability
from collections import defaultdict
import json


class OpenToMEvaluatorDspy:
def __init__(self, model_name="") -> None:
self.true_positives = defaultdict(lambda: 0)
self.false_positives = defaultdict(lambda: 0)
self.false_negatives = defaultdict(lambda: 0)
self.model_name = model_name

    def dspy_metric(self, example, pred_answer, trace=None):
        question_type = example.type  # avoid shadowing the builtin `type`
        eval_result = self.check_answer(example, pred_answer.answer)
        if eval_result is None:
            # invalid/filtered example; it is unclear what a dspy metric should
            # return in this case, so None at least makes the skip visible
            return None
        gt, pred = eval_result  # ground-truth answer class, predicted answer class
        # store results by class so we can compute the F1 scores later
        if gt == pred:
            self.true_positives[f"{question_type}_{pred}"] += 1
        else:
            self.false_positives[f"{question_type}_{pred}"] += 1
            self.false_negatives[f"{question_type}_{gt}"] += 1
        return gt == pred

    # this method was added to make dspy evaluation easier
    def check_answer(
        self,
        example,
        pred_answer,
        cot_flag=False,
        perspective="all",
    ):
        # relies on the plot_info JSON preserving key order:
        # mover, affected character, entity of interest, original place, destination
        mover, affected_char, eoi, original_place, move_to_place = json.loads(
            example.plot_info
        ).values()
        cur_question_type = example.type
        question_content = example.question
        gt_answer = example.answer.strip()
        pred_answer = pred_answer.strip()
        # NOTE: optionally restrict evaluation to questions about one character
        # ("observer" = the affected character, "mover" = the mover); "all" keeps everything
        if perspective == "observer":
            # keep only questions whose subject is the affected character
            if mover in question_content and affected_char not in question_content:
                return None
            if mover in question_content and affected_char in question_content:
                # when both characters appear, the one mentioned first is
                # treated as the question's subject
                question_tokens = (
                    question_content.replace("'s", "").replace(",", "").split()
                )
                mover_idx = question_tokens.index(mover)
                affected_char_idx = question_tokens.index(affected_char)
                if mover_idx < affected_char_idx:
                    return None
        elif perspective == "mover":
            # keep only questions whose subject is the mover
            if mover not in question_content and affected_char in question_content:
                return None
            if mover in question_content and affected_char in question_content:
                question_tokens = (
                    question_content.replace("'s", "").replace(",", "").split()
                )
                mover_idx = question_tokens.index(mover)
                affected_char_idx = question_tokens.index(affected_char)
                if mover_idx > affected_char_idx:
                    return None
if cot_flag:
pred_answer = self.parse_cot_answer(pred_answer)
        # fo = first-order, so = second-order; both orders share the same checkers
        if cur_question_type in ("location-fo-coarse", "location-so-coarse"):
            return self.check_answer_for_cg_location(pred_answer, gt_answer)
        elif cur_question_type in ("location-fo-fine", "location-so-fine"):
            return self.check_answer_for_fg_location(
                pred_answer, gt_answer, original_place, move_to_place
            )
        elif cur_question_type in ("multihop-fo", "multihop-so"):
            if "fullness" in question_content:
                return self.check_fullness_answer(pred_answer, gt_answer)
            elif "accessibility" in question_content:
                # a "|" in the gold answer means both options hold, i.e. the
                # object is equally accessible (the upstream list handling was
                # dropped since gt_answer is always a string here)
                if "|" in gt_answer:
                    gt_answer = "equally accessible"
                return self.check_accessibility_answer(pred_answer, gt_answer)
        elif cur_question_type == "attitude":
            return self.check_attitude_answer(pred_answer, gt_answer)
        # unknown question type (or a multihop question that is neither fullness
        # nor accessibility): treat as an invalid example
        return None

    # compute a per-class F1 score and a macro-averaged F1 score per question type
    def f1_score(self):
        true_positives = self.true_positives
        false_positives = self.false_positives
        false_negatives = self.false_negatives
        f1_scores = defaultdict(lambda: {"by_class": {}})
for _class in (
true_positives.keys() | false_positives.keys() | false_negatives.keys()
):
question_type, _ = _class.split("_")
class_true_positives = true_positives[_class]
class_false_positives = false_positives[_class]
class_false_negatives = false_negatives[_class]
class_precision = (
class_true_positives / (class_true_positives + class_false_positives)
if class_true_positives > 0.0
else 0.0
) # avoid dividing by zero
class_recall = (
class_true_positives / (class_true_positives + class_false_negatives)
if class_true_positives > 0.0
else 0.0
)
class_f1_score = (
(2 * class_precision * class_recall) / (class_precision + class_recall)
if class_precision > 0.0 or class_recall > 0.0
else 0.0
)
f1_scores[question_type]["by_class"][_class] = class_f1_score
        for question_type, type_scores in f1_scores.items():
            class_scores = type_scores["by_class"]
            type_scores["macro_averaged"] = sum(class_scores.values()) / len(
                class_scores
            )
return f1_scores
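    # For reference, f1_score() returns a nested dict shaped roughly like this
    # (the class keys shown are illustrative, not exhaustive):
    #   {"attitude": {"by_class": {"attitude_1": 0.8, "attitude_2": 0.5},
    #                 "macro_averaged": 0.65}, ...}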

    # pretty-print the macro-averaged F1 score for each question type
    def print_f1_results(self, round_decimal=2, print_header=False):
        f1_scores = self.f1_score()
        if print_header:
            print("Macro Averaged F1 Scores by question type")
        print(self.model_name, end=" - ")
        for question_type, type_f1_scores in f1_scores.items():
            # scale to a percentage before rounding to avoid floating-point noise
            print(
                f"{question_type}: {round(type_f1_scores['macro_averaged'] * 100, round_decimal)}",
                end="\t",
            )
        print()

    @staticmethod
    def remove_determinant(word: str) -> str:
        # strip a leading determiner, matching whole words only, so e.g.
        # "the box" -> "box" but "attic" stays "attic"
        determinants = ["a", "an", "the"]
        first, _, rest = word.partition(" ")
        if first.lower() in determinants:
            return rest.strip()
        return word

    @staticmethod
    def compute_lexical_overlap(pred: str, location: str) -> float:
        # fraction of the location's words that also appear in the prediction
        pred = pred.lower().replace("_", " ").replace("'s", "")
        location = location.lower().replace("_", " ").replace("'s", "")
        score = 0
pred = pred.replace(".", "").split()
location = location.split()
visited_word = []
for word in pred:
if word in location and word not in visited_word:
score += 1
visited_word.append(word)
return score / len(location)
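    # hand-checked examples:
    #   compute_lexical_overlap("It is in the red box.", "red box") == 1.0
    #   compute_lexical_overlap("the blue box", "red box") == 0.5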

    def parse_cot_answer(self, answer: str) -> str:
        # CoT outputs typically state the final answer in the last line or
        # after a concluding "Therefore"
        if "\n" in answer:
            answer = answer.split("\n")[-1]
        else:
            answer = answer.split("Therefore")[-1]
        return answer
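    # e.g. parse_cot_answer("Sally saw the move. Therefore, the box is fuller.")
    # returns ", the box is fuller."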

    def check_answer_for_fg_location(
        self, prediction: str, answer: str, original_place: str, move_to_place: str
    ) -> list:
        # labels: 1 = original place, 2 = destination, 3 = tie/ambiguous
        answer = self.remove_determinant(answer).lower()
        original_place = self.remove_determinant(original_place).lower()
        move_to_place = self.remove_determinant(move_to_place).lower()
        gt_label, pred_label = None, None
        # score the prediction against both candidate locations by lexical overlap,
        # since predictions often embed the location in an explanation
        original_place_score = self.compute_lexical_overlap(prediction, original_place)
        move_to_place_score = self.compute_lexical_overlap(prediction, move_to_place)
        if original_place_score > move_to_place_score:
            pred_label = 1
        elif original_place_score < move_to_place_score:
            pred_label = 2
        else:
            pred_label = 3
        if original_place == answer:
            gt_label = 1
        elif move_to_place == answer:
            gt_label = 2
        return [gt_label, pred_label]

    def check_answer_for_cg_location(self, prediction: str, answer: str) -> list:
        # labels: 0 = "no", 1 = "yes", -1 = corrupted/ambiguous
        # NOTE: substring matching, so e.g. "know"/"not" also count as "no"
        prediction = prediction.lower()
        answer = answer.lower()
        gt_label, pred_label = None, -1
        if "no" in prediction and "yes" not in prediction:
            pred_label = 0
        elif "yes" in prediction and "no" not in prediction:
            pred_label = 1
        if "no" in answer:
            gt_label = 0
        elif "yes" in answer:
            gt_label = 1
        return [gt_label, pred_label]

    def check_fullness_answer(self, prediction: str, answer: str) -> list:
        # labels: 1 = less full, 2 = more full, 3 = equally full, -1 = corrupted
        prediction = prediction.replace(".", "").lower()
        less_full_answer_list = ["less full", "emptier", "more empty"]
        more_full_answer_list = ["more full", "fuller"]
        pred_label, gt_label = None, None
        if any(phrase in prediction for phrase in less_full_answer_list):
            pred_label = 1
        elif any(phrase in prediction for phrase in more_full_answer_list):
            pred_label = 2
        elif "equally full" in prediction:
            pred_label = 3
        else:
            pred_label = -1  # corrupted
        if answer == "less full":
            gt_label = 1
        elif answer == "more full":
            gt_label = 2
        elif answer == "equally full":
            gt_label = 3
        return [gt_label, pred_label]
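    # hand-checked example:
    #   check_fullness_answer("The bucket is now emptier.", "less full") -> [1, 1]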

    def check_accessibility_answer(self, prediction: str, answer: str) -> list:
        # labels: 1 = more accessible, 2 = less accessible, 3 = equally accessible
        prediction = prediction.replace(".", "").lower()
        pred_label, gt_label = None, None
if "more accessible" in prediction:
pred_label = 1
elif "less accessible" in prediction:
pred_label = 2
elif "equally accessible" in prediction:
pred_label = 3
else:
pred_label = -1 # corrupted
if answer == "more accessible":
gt_label = 1
elif answer == "less accessible":
gt_label = 2
else:
gt_label = 3
return [gt_label, pred_label]

    def check_attitude_answer(self, prediction: str, answer: str) -> list:
        # labels: 1 = positive, 2 = negative, 3 = neutral, -1 = corrupted
        prediction = prediction.lower()
        answer = answer.lower()
        answer_map = {"a": "positive", "b": "neutral", "c": "negative"}
        # the answer usually sits after the last colon of the last paragraph,
        # e.g. "Answer: B." -> "b"
        prediction_token = (
            prediction.split("\n\n")[-1].split(":")[-1].split(".")[0].strip().lower()
        )
        gt_label, pred_label = None, None
        if answer == "positive":
            gt_label = 1
        elif answer == "negative":
            gt_label = 2
        else:
            gt_label = 3
        try:
            # multiple-choice answers ("a"/"b"/"c") map to attitude words
            mapped = answer_map[prediction_token]
            if mapped == "positive":
                pred_label = 1
            elif mapped == "negative":
                pred_label = 2
            else:
                pred_label = 3
        except KeyError:
            # free-form answer: accept exactly one attitude word;
            # more than one (or none) counts as corrupted
            matches = [
                word
                for word in ("positive", "negative", "neutral")
                if word in prediction_token
            ]
            if matches == ["positive"]:
                pred_label = 1
            elif matches == ["negative"]:
                pred_label = 2
            elif matches == ["neutral"]:
                pred_label = 3
            else:
                pred_label = -1
        return [gt_label, pred_label]
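

# A minimal usage sketch (not part of the original repo): SimpleNamespace stands
# in for the dspy example/prediction objects, whose relevant fields (`type`,
# `question`, `answer`, `plot_info`, and `.answer` on the prediction) are
# exactly what dspy_metric() and check_answer() read above. The field values
# below are made up for illustration.
if __name__ == "__main__":
    from types import SimpleNamespace

    evaluator = OpenToMEvaluatorDspy(model_name="demo-model")
    example = SimpleNamespace(
        type="location-fo-coarse",
        question="Is the rubber duck still in its initial location?",
        answer="Yes",
        plot_info=json.dumps(
            {
                "mover": "Anne",
                "affected_char": "Sally",
                "eoi": "rubber duck",
                "original_place": "the basket",
                "move_to_place": "the box",
            }
        ),
    )
    prediction = SimpleNamespace(answer="Yes, it is still there.")
    print(evaluator.dspy_metric(example, prediction))  # True
    evaluator.print_f1_results(print_header=True)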