Files
tweet-optimizer-v2/agent/models.py
2025-10-19 20:03:29 -04:00

75 lines
2.9 KiB
Python

from pydantic import BaseModel, Field, validator
from typing import List
from .constants import MIN_SCORE, MAX_SCORE
class CategoryEvaluation(BaseModel):
    """A single category's evaluation: name, reasoning, and integer score.

    The score is constrained to the inclusive range
    ``MIN_SCORE``..``MAX_SCORE``.
    """

    category: str = Field(description="The evaluation category name")
    reasoning: str = Field(description="Explanation for the score")
    # The range is declared on the field and checked again by the validator.
    score: int = Field(
        description=f"Score for this category ({MIN_SCORE}-{MAX_SCORE})",
        ge=MIN_SCORE,
        le=MAX_SCORE,
    )

    @validator("score")
    def validate_score(cls, value):
        """Return ``value`` unchanged when it is an in-range integer.

        Raises:
            ValueError: if ``value`` is not an ``int`` or lies outside
                ``MIN_SCORE``..``MAX_SCORE``.
        """
        # The isinstance test comes first so the range comparison only
        # ever runs on integers.
        acceptable = isinstance(value, int) and MIN_SCORE <= value <= MAX_SCORE
        if acceptable:
            return value
        raise ValueError(
            f"Score {value} must be an integer between {MIN_SCORE} and {MAX_SCORE}"
        )
class FinalResult(BaseModel):
    """Final result of a tweet optimization.

    Carries the initial and final tweet text plus the statistics named by
    the fields below.
    """

    # Tweet text before and after optimization.
    initial_text: str = Field(
        description="The initial tweet text",
    )
    final_tweet: str = Field(
        description="The final optimized tweet",
    )

    # Score and run statistics.
    best_score: float = Field(
        description="The best score for the final tweet",
    )
    iterations_run: int = Field(
        description="The number of iterations run",
    )
    early_stopped: bool = Field(
        description="Whether the optimization early stopped",
    )
    # A list of integer-score lists.
    scores_history: List[List[int]] = Field(
        description="The history of scores",
    )
    improvement_count: int = Field(
        description="The number of improvements found",
    )
class EvaluationResult(BaseModel):
    """Tweet evaluation made up of one or more category evaluations.

    Instances compare by total score only: ``==`` and the ordering
    operators look at ``total_score()``, so two results with different
    categories or reasoning compare equal when their totals match.
    """

    evaluations: List[CategoryEvaluation] = Field(
        description="List of category evaluations with reasoning and scores"
    )

    @validator("evaluations")
    def validate_evaluations(cls, evals):
        """Reject an empty list of evaluations.

        Raises:
            ValueError: if ``evals`` is empty.
        """
        # An empty list is falsy, so no separate length check is needed.
        if not evals:
            raise ValueError("Must have at least one category evaluation")
        return evals

    @property
    def category_scores(self) -> List[int]:
        """Get list of scores for backwards compatibility."""
        return [evaluation.score for evaluation in self.evaluations]

    def total_score(self) -> int:
        """Return the sum of the scores of all category evaluations."""
        return sum(evaluation.score for evaluation in self.evaluations)

    def average_score(self) -> float:
        """Return the total score divided by the number of evaluations."""
        return self.total_score() / len(self.evaluations)

    # Rich comparisons: every operator compares total scores and returns
    # NotImplemented for foreign types so Python can try the reflected
    # operation or raise TypeError itself.

    def __gt__(self, other):
        """Return True if this result's total score is greater."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() > other.total_score()

    def __ge__(self, other):
        """Return True if this result's total score is greater or equal."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() >= other.total_score()

    def __lt__(self, other):
        """Return True if this result's total score is smaller."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() < other.total_score()

    def __le__(self, other):
        """Return True if this result's total score is smaller or equal."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() <= other.total_score()

    def __eq__(self, other):
        """Return True if both results have the same total score."""
        if not isinstance(other, EvaluationResult):
            return NotImplemented
        return self.total_score() == other.total_score()