Compiled CoTWithThoughtSimplifiedBaleen with bootstrap_fewshot_with_random_search for multihop-so
This commit is contained in:
367
opentom_evaluator.py
Normal file
367
opentom_evaluator.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# taken from https://github.com/seacowx/OpenToM/blob/main/src/evaluate/opentom_evaluator.py
|
||||
# modified for usability
|
||||
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import traceback
|
||||
|
||||
|
||||
class OpenToMEvaluatorDspy:
    """Evaluator for OpenToM question answering, adapted for use as a dspy metric.

    Accumulates per-class true/false positive/negative counts across calls to
    ``dspy_metric`` so macro-averaged F1 scores per question type can be
    computed afterwards via ``f1_score`` / ``print_f1_results``.
    """

    def __init__(self, model_name="") -> None:
        # Counters keyed by "<question_type>_<label>" so per-class F1 scores
        # can be derived later in f1_score().
        self.true_positives = defaultdict(lambda: 0)
        self.false_positives = defaultdict(lambda: 0)
        self.false_negatives = defaultdict(lambda: 0)
        self.model_name = model_name

    def dspy_metric(self, example, pred_answer, trace=None):
        """dspy-compatible metric.

        Returns True iff the predicted answer class matches the ground-truth
        class, or None when ``check_answer`` rejects the example. Also updates
        the internal TP/FP/FN counters as a side effect.
        """
        question_type = example.type

        eval_result = self.check_answer(example, pred_answer.answer)
        if (
            eval_result is None
        ):  # Hm what is the correct value to return as a dspy metric when there's an invalid example?
            return None
        gt, pred = eval_result  # ground truth answer class, predicted answer class

        # store positive/negative results by class so we can calculate the f1 scores later
        if gt == pred:
            self.true_positives[f"{question_type}_{pred}"] += 1
        else:
            self.false_positives[f"{question_type}_{pred}"] += 1
            self.false_negatives[f"{question_type}_{gt}"] += 1

        return gt == pred

    # this method was added to make dspy evaluation easier
    def check_answer(
        self,
        example,
        pred_answer,
        cot_flag=False,
        perspective="all",
    ):
        """Map ground-truth and predicted answers to label classes.

        Returns a ``(gt_label, pred_label)`` pair, or None when the example is
        filtered out by ``perspective`` or its question type is not recognised.
        """
        # NOTE(review): relies on JSON object key order matching
        # (mover, affected_char, eoi, original_place, move_to_place) — confirm
        # against the dataset's plot_info schema.
        mover, affected_char, eoi, original_place, move_to_place = json.loads(
            example.plot_info
        ).values()

        cur_question_type = example.type
        question_content = example.question

        gt_answer = example.answer.strip()
        pred_answer = pred_answer.strip()

        # NOTE: evaluate based on the character the question is about
        if perspective == "observer":
            # skip questions that are about the mover only
            if mover in question_content and affected_char not in question_content:
                return None

            if mover in question_content and affected_char in question_content:
                question_tokens = (
                    question_content.replace("'s", "").replace(",", "").split()
                )

                mover_idx = question_tokens.index(mover)
                affected_char_idx = question_tokens.index(affected_char)

                # mover mentioned first -> treat as a mover question, skip
                if mover_idx < affected_char_idx:
                    return None

        elif perspective == "mover":
            # skip questions that are about the observer only
            if mover not in question_content and affected_char in question_content:
                return None

            if mover in question_content and affected_char in question_content:
                question_tokens = (
                    question_content.replace("'s", "").replace(",", "").split()
                )

                mover_idx = question_tokens.index(mover)
                affected_char_idx = question_tokens.index(affected_char)

                # observer mentioned first -> treat as an observer question, skip
                if mover_idx > affected_char_idx:
                    return None

        if cot_flag:
            pred_answer = self.parse_cot_answer(pred_answer)

        # first-order and second-order variants share identical scoring logic,
        # so the branches are merged via membership tests
        if cur_question_type in ("location-fo-coarse", "location-so-coarse"):
            gt, pred = self.check_answer_for_cg_location(pred_answer, gt_answer)
            return gt, pred

        elif cur_question_type in ("location-fo-fine", "location-so-fine"):
            gt, pred = self.check_answer_for_fg_location(
                pred_answer, gt_answer, original_place, move_to_place
            )
            return gt, pred

        elif cur_question_type in ("multihop-fo", "multihop-so"):
            if "fullness" in question_content:
                gt, pred = self.check_fullness_answer(pred_answer, gt_answer)
                return gt, pred

            elif "accessibility" in question_content:
                # "|" joins multiple annotations; treat disagreement as a tie
                if "|" in gt_answer:
                    gt_answer = "equally accessible"

                # NOTE(review): gt_answer is a str here (it was .strip()ed
                # above), so this list branch looks unreachable — kept for
                # parity with the upstream evaluator.
                if isinstance(gt_answer, list):
                    gt_answer = [ele for ele in gt_answer if ele != "corrupted"]
                    assert len(gt_answer) == 1
                    gt_answer = gt_answer[0]

                gt, pred = self.check_accessibility_answer(pred_answer, gt_answer)
                return gt, pred

        elif cur_question_type == "attitude":
            gt, pred = self.check_attitude_answer(pred_answer, gt_answer)
            return gt, pred

        # unrecognised question type (or multihop without a known sub-type)
        return None

    def f1_score(self):
        """Compute per-class and macro-averaged F1 scores per question type.

        Returns a dict: question_type -> {"by_class": {class: f1},
        "macro_averaged": float}.
        """
        true_positives = self.true_positives
        false_positives = self.false_positives
        false_negatives = self.false_negatives
        f1_scores = defaultdict(lambda: {"by_class": {}})

        # union of every class seen in any counter
        for _class in (
            true_positives.keys() | false_positives.keys() | false_negatives.keys()
        ):
            # keys look like "<question_type>_<label>"; question types use
            # hyphens, so splitting on "_" is safe
            question_type, _ = _class.split("_")
            class_true_positives = true_positives[_class]
            class_false_positives = false_positives[_class]
            class_false_negatives = false_negatives[_class]
            class_precision = (
                class_true_positives / (class_true_positives + class_false_positives)
                if class_true_positives > 0.0
                else 0.0
            )  # avoid dividing by zero
            class_recall = (
                class_true_positives / (class_true_positives + class_false_negatives)
                if class_true_positives > 0.0
                else 0.0
            )
            class_f1_score = (
                (2 * class_precision * class_recall) / (class_precision + class_recall)
                if class_precision > 0.0 or class_recall > 0.0
                else 0.0
            )
            f1_scores[question_type]["by_class"][_class] = class_f1_score

        # macro average = unweighted mean of the per-class F1 scores
        for question_type, type_f1_scores in f1_scores.items():
            type_f1_scores = type_f1_scores["by_class"]
            macro_averaged_f1_score = sum(list(type_f1_scores.values())) / len(
                type_f1_scores
            )
            f1_scores[question_type]["macro_averaged"] = macro_averaged_f1_score

        return f1_scores

    # pretty print macro averaged f1 scores for each question type
    def print_f1_results(self, round_decimal=2, print_header=False):
        f1_scores = self.f1_score()
        if print_header:
            print("Macro Averaged F1 Scores by question type")

        print(self.model_name, end=" - ")
        for question_type, type_f1_scores in f1_scores.items():
            # round AFTER scaling to a percentage; rounding first and then
            # multiplying produced float artifacts like 66.67000000000001
            score = round(type_f1_scores["macro_averaged"] * 100, round_decimal)
            print(f"{question_type}: {score}", end="\t")
        print()

    @staticmethod
    def remove_determinant(word: str) -> str:
        """Strip a leading determinant ("a", "an", "the") from ``word``."""
        determinants = ["a", "an", "the"]
        for det in determinants:
            # require a following space so the determinant is a whole word;
            # a bare prefix check mangled e.g. "apple" -> "pple"
            if word.startswith(det + " "):
                return word[len(det) :].strip()
        return word

    @staticmethod
    def compute_lexical_overlap(pred: str, location: str) -> float:
        """Fraction of distinct words of ``location`` that appear in ``pred``."""
        pred = pred.lower().replace("_", " ").replace("'s", "")
        location = location.lower().replace("_", " ").replace("'s", "")
        score = 0
        pred = pred.replace(".", "").split()
        location = location.split()
        visited_word = []

        # count each location word at most once
        for word in pred:
            if word in location and word not in visited_word:
                score += 1
                visited_word.append(word)

        # NOTE(review): raises ZeroDivisionError for an empty location string —
        # assumed not to occur in the dataset
        return score / len(location)

    def parse_cot_answer(self, answer: str) -> str:
        """Extract the final answer from a chain-of-thought generation."""
        # cot typically generates the answer in the last line or after "Therefore"
        if "\n" in answer:
            answer = answer.split("\n")[-1]
        else:
            answer = answer.split("Therefore")[-1]
        return answer

    def check_answer_for_fg_location(
        self, prediction: str, answer: str, original_place: str, move_to_place: str
    ) -> list:
        """Score a fine-grained location answer by lexical overlap.

        Labels: 1 = original place, 2 = move-to place, 3 = tie/undecided.
        """
        # truncate prediction as some of them contain explanations
        answer = self.remove_determinant(answer).lower()
        original_place = self.remove_determinant(original_place).lower()
        move_to_place = self.remove_determinant(move_to_place).lower()
        gt_label, pred_label = None, None
        original_place_score = self.compute_lexical_overlap(prediction, original_place)
        move_to_place_score = self.compute_lexical_overlap(prediction, move_to_place)

        # pick whichever candidate location overlaps the prediction more
        if original_place_score == move_to_place_score:
            pred_label = 3
        if original_place_score > move_to_place_score:
            pred_label = 1
        elif original_place_score < move_to_place_score:
            pred_label = 2

        if original_place == answer:
            gt_label = 1
        elif move_to_place == answer:
            gt_label = 2

        return [gt_label, pred_label]

    def check_answer_for_cg_location(self, prediction: str, answer: str) -> list:
        """Score a coarse-grained (yes/no) location answer.

        Labels: 0 = no, 1 = yes, -1 = ambiguous/corrupted prediction.
        """
        prediction = prediction.lower()
        answer = answer.lower()

        if "no" in prediction and "yes" not in prediction:
            pred_label = 0
        elif "yes" in prediction and "no" not in prediction:
            pred_label = 1
        else:
            pred_label = -1

        if "no" in answer:
            gt_label = 0
        elif "yes" in answer:
            gt_label = 1

        return [gt_label, pred_label]

    def check_fullness_answer(self, prediction: str, answer: str) -> list:
        """Score a fullness answer.

        Labels: 1 = less full, 2 = more full, 3 = equally full, -1 = corrupted.
        """
        prediction = prediction.replace(".", "").lower()
        less_full_answer_list = ["less full", "emptier", "more empty"]
        more_full_answer_list = ["more full", "fuller"]
        pred_label, gt_label = None, None

        # check the paraphrase lists in priority order: less, more, equal
        for less_full_ans in less_full_answer_list:
            if less_full_ans in prediction:
                pred_label = 1

        if not pred_label:
            for more_full_ans in more_full_answer_list:
                if more_full_ans in prediction:
                    pred_label = 2

        if not pred_label:
            if "equally full" in prediction:
                pred_label = 3

        if not pred_label:
            pred_label = -1  # corrupted

        if answer == "less full":
            gt_label = 1
        elif answer == "more full":
            gt_label = 2
        elif answer == "equally full":
            gt_label = 3

        return [gt_label, pred_label]

    def check_accessibility_answer(self, prediction: str, answer: str) -> list:
        """Score an accessibility answer.

        Labels: 1 = more, 2 = less, 3 = equally accessible, -1 = corrupted.
        """
        prediction = prediction.replace(".", "").lower()
        pred_label, gt_label = None, None
        if "more accessible" in prediction:
            pred_label = 1
        elif "less accessible" in prediction:
            pred_label = 2
        elif "equally accessible" in prediction:
            pred_label = 3
        else:
            pred_label = -1  # corrupted

        if answer == "more accessible":
            gt_label = 1
        elif answer == "less accessible":
            gt_label = 2
        else:
            gt_label = 3

        return [gt_label, pred_label]

    def check_attitude_answer(self, prediction: str, answer: str) -> list:
        """Score an attitude answer.

        Labels: 1 = positive, 2 = negative, 3 = neutral, -1 = corrupted.
        Predictions may be a letter choice ("a"/"b"/"c") or free text.
        """
        prediction = prediction.lower()
        answer = answer.lower()
        # multiple-choice letter -> attitude word
        answer_map = {"a": "positive", "b": "neutral", "c": "negative"}
        # take the last paragraph, drop any "label:" prefix and explanations
        prediction_token = (
            prediction.split("\n\n")[-1].split(":")[-1].split(".")[0].strip().lower()
        )
        gt_label, pred_label = None, None

        if answer == "positive":
            gt_label = 1
        elif answer == "negative":
            gt_label = 2
        else:
            gt_label = 3

        try:
            prediction = answer_map[prediction_token]
            if prediction == "positive":
                pred_label = 1
            elif prediction == "negative":
                pred_label = 2
            else:
                pred_label = 3

        # only the answer_map lookup can fail here; a bare except would also
        # hide genuine programming errors
        except KeyError:
            # free-text answer: ambiguous if it mentions more than one attitude
            if "positive" in prediction_token and "negative" in prediction_token:
                pred_label = -1
            elif "positive" in prediction_token and "neutral" in prediction_token:
                pred_label = -1
            elif "neutral" in prediction_token and "negative" in prediction_token:
                pred_label = -1
            elif "positive" in prediction_token:
                pred_label = 1
            elif "negative" in prediction_token:
                pred_label = 2
            elif "neutral" in prediction_token:
                pred_label = 3
            else:
                pred_label = -1

        return [gt_label, pred_label]
|
||||
Reference in New Issue
Block a user