Files
tweet-optimizer-v2/models.py
2025-10-19 19:17:58 -04:00

61 lines
2.3 KiB
Python

from pydantic import BaseModel, Field, validator
from typing import List
from constants import MIN_SCORE, MAX_SCORE
class CategoryEvaluation(BaseModel):
    """Pydantic model for a single category evaluation with reasoning."""

    # The class docstring above is left untouched on purpose: Pydantic may
    # surface it as the model's schema description, so it is runtime-visible.
    #
    # Fields:
    #   category  - name of the category being judged
    #   reasoning - free-text justification for the score
    #   score     - integer bounded by MIN_SCORE and MAX_SCORE (inclusive)
    category: str = Field(description="The evaluation category name")
    reasoning: str = Field(description="Explanation for the score")
    score: int = Field(
        description=f"Score for this category ({MIN_SCORE}-{MAX_SCORE})",
        ge=MIN_SCORE,
        le=MAX_SCORE,
    )

    @validator('score')
    def validate_score(cls, value):
        """Accept ``value`` only when it is an in-range integer.

        Returns:
            The score, unchanged.

        Raises:
            ValueError: If ``value`` is not an ``int`` or falls outside the
                inclusive ``MIN_SCORE``..``MAX_SCORE`` range.
        """
        # Happy path first; everything else falls through to the error.
        if isinstance(value, int) and MIN_SCORE <= value <= MAX_SCORE:
            return value
        raise ValueError(f"Score {value} must be an integer between {MIN_SCORE} and {MAX_SCORE}")
class EvaluationResult(BaseModel):
    """Pydantic model for tweet evaluation results."""

    # The class docstring above is left untouched on purpose: Pydantic may
    # surface it as the model's schema description, so it is runtime-visible.
    #
    # Comparison semantics: two results are ordered and compared for equality
    # solely by ``total_score()``. Results with different categories or
    # reasoning therefore compare equal whenever their totals match.
    evaluations: List[CategoryEvaluation] = Field(
        description="List of category evaluations with reasoning and scores"
    )

    @validator('evaluations')
    def validate_evaluations(cls, evals):
        """Reject an empty evaluation list.

        Returns:
            The list of evaluations, unchanged.

        Raises:
            ValueError: If ``evals`` is empty.
        """
        if not evals:
            raise ValueError("Must have at least one category evaluation")
        return evals

    @property
    def category_scores(self) -> List[int]:
        """Get list of scores for backwards compatibility.

        Scores appear in the same order as ``self.evaluations``.
        """
        return [evaluation.score for evaluation in self.evaluations]

    def total_score(self) -> int:
        """Return the sum of the scores of all category evaluations."""
        return sum(evaluation.score for evaluation in self.evaluations)

    def average_score(self) -> float:
        """Return the mean score across all category evaluations.

        Validated instances always hold at least one evaluation (see
        ``validate_evaluations``), so the divisor is non-zero for them.
        """
        return self.total_score() / len(self.evaluations)

    # --- Rich comparisons, all keyed on total_score() -----------------------
    # Every method returns NotImplemented for foreign types so Python can try
    # the reflected operation or raise TypeError itself.

    def __gt__(self, other):
        """Return True when this result's total score exceeds ``other``'s."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() > other.total_score()

    def __ge__(self, other):
        """Return True when this result's total score is at least ``other``'s."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() >= other.total_score()

    def __lt__(self, other):
        """Return True when this result's total score is below ``other``'s."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() < other.total_score()

    def __le__(self, other):
        """Return True when this result's total score is at most ``other``'s."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() <= other.total_score()

    def __eq__(self, other):
        """Return True when both results have the same total score."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() == other.total_score()