Complete Migration
agent/modules.py (129 changed lines)
@@ -1,7 +1,7 @@
 import dspy
 from typing import List, Optional
-from models import EvaluationResult, CategoryEvaluation
-from constants import (
+from .models import EvaluationResult, CategoryEvaluation
+from .constants import (
     TWEET_MAX_LENGTH,
     TWEET_TRUNCATION_SUFFIX,
     DEFAULT_SCORE,
@@ -10,87 +10,115 @@ from constants import (
     ERROR_GENERATION,
     ERROR_EVALUATION,
     MIN_SCORE,
-    MAX_SCORE
+    MAX_SCORE,
 )
-from helpers import format_evaluation_for_generator, truncate_tweet
+from .helpers import format_evaluation_for_generator, truncate_tweet
 
 
 class TweetGenerator(dspy.Signature):
     """Generate or improve a tweet based on input text and detailed evaluation feedback with reasoning."""
 
     input_text: str = dspy.InputField(desc="Original text or current tweet to improve")
-    current_tweet: str = dspy.InputField(desc="Current best tweet version (empty for first generation)")
-    previous_evaluation: str = dspy.InputField(desc="Previous evaluation with category-by-category reasoning and scores (empty for first generation)")
-    improved_tweet: str = dspy.OutputField(desc=f"Generated or improved tweet text (max {TWEET_MAX_LENGTH} characters)")
+    current_tweet: str = dspy.InputField(
+        desc="Current best tweet version (empty for first generation)"
+    )
+    previous_evaluation: str = dspy.InputField(
+        desc="Previous evaluation with category-by-category reasoning and scores (empty for first generation)"
+    )
+    improved_tweet: str = dspy.OutputField(
+        desc=f"Generated or improved tweet text (max {TWEET_MAX_LENGTH} characters)"
+    )
 
 
 class TweetEvaluator(dspy.Signature):
     """Evaluate a tweet across multiple custom categories. For each category, provide detailed reasoning explaining the score, then assign a score. Ensure the tweet maintains the same meaning as the original text."""
 
-    original_text: str = dspy.InputField(desc="Original input text that started the optimization")
-    current_best_tweet: str = dspy.InputField(desc="Current best tweet version for comparison (empty for first evaluation)")
+    original_text: str = dspy.InputField(
+        desc="Original input text that started the optimization"
+    )
+    current_best_tweet: str = dspy.InputField(
+        desc="Current best tweet version for comparison (empty for first evaluation)"
+    )
     tweet_text: str = dspy.InputField(desc="Tweet text to evaluate")
-    categories: str = dspy.InputField(desc="Comma-separated list of evaluation category descriptions")
+    categories: str = dspy.InputField(
+        desc="Comma-separated list of evaluation category descriptions"
+    )
     evaluations: List[CategoryEvaluation] = dspy.OutputField(
         desc=f"List of evaluations with category name, detailed reasoning, and score ({MIN_SCORE}-{MAX_SCORE}) for each category. Ensure the tweet conveys the same meaning as the original text."
     )
 
 
 class TweetGeneratorModule(dspy.Module):
     """DSPy module for generating and improving tweets."""
 
     def __init__(self):
         super().__init__()
         self.generate = dspy.ChainOfThought(TweetGenerator)
 
-    def forward(self, input_text: str, current_tweet: str = "", previous_evaluation: Optional[EvaluationResult] = None) -> str:
+    def forward(
+        self,
+        input_text: str,
+        current_tweet: str = "",
+        previous_evaluation: Optional[EvaluationResult] = None,
+    ) -> str:
         """Generate or improve a tweet."""
         try:
             # Format previous evaluation as text
             eval_text = format_evaluation_for_generator(previous_evaluation)
 
             result = self.generate(
                 input_text=input_text,
                 current_tweet=current_tweet,
-                previous_evaluation=eval_text
+                previous_evaluation=eval_text,
             )
 
             # Ensure tweet doesn't exceed character limit
-            tweet = truncate_tweet(result.improved_tweet, TWEET_MAX_LENGTH, TWEET_TRUNCATION_SUFFIX)
+            tweet = truncate_tweet(
+                result.improved_tweet, TWEET_MAX_LENGTH, TWEET_TRUNCATION_SUFFIX
+            )
 
             return tweet
         except Exception as e:
             raise Exception(f"{ERROR_GENERATION}: {str(e)}")
 
 
 class TweetEvaluatorModule(dspy.Module):
     """DSPy module for evaluating tweets across custom categories."""
 
     def __init__(self):
         super().__init__()
         self.evaluate = dspy.ChainOfThought(TweetEvaluator)
 
-    def forward(self, tweet_text: str, categories: List[str], original_text: str = "", current_best_tweet: str = "") -> EvaluationResult:
+    def forward(
+        self,
+        tweet_text: str,
+        categories: List[str],
+        original_text: str = "",
+        current_best_tweet: str = "",
+    ) -> EvaluationResult:
         """Evaluate a tweet across specified categories."""
         try:
             # Join categories into comma-separated string
             categories_str = ", ".join(categories)
 
             result = self.evaluate(
                 original_text=original_text,
                 current_best_tweet=current_best_tweet,
                 tweet_text=tweet_text,
-                categories=categories_str
+                categories=categories_str,
             )
 
             # Extract and validate evaluations
             evaluations = result.evaluations
 
             # Ensure we have the right number of evaluations
             if len(evaluations) != len(categories):
                 # Create default evaluations if mismatch
                 evaluations = [
                     CategoryEvaluation(
-                        category=cat,
-                        reasoning=ERROR_PARSING,
-                        score=DEFAULT_SCORE
-                    ) for cat in categories
+                        category=cat, reasoning=ERROR_PARSING, score=DEFAULT_SCORE
+                    )
+                    for cat in categories
                 ]
             else:
                 # Validate each evaluation
@@ -99,22 +127,32 @@ class TweetEvaluatorModule(dspy.Module):
                     try:
                         # Ensure score is valid
                         score = max(MIN_SCORE, min(MAX_SCORE, int(eval.score)))
-                        validated_evals.append(CategoryEvaluation(
-                            category=categories[i] if i < len(categories) else eval.category,
-                            reasoning=eval.reasoning if eval.reasoning else "No reasoning provided",
-                            score=score
-                        ))
+                        validated_evals.append(
+                            CategoryEvaluation(
+                                category=categories[i]
+                                if i < len(categories)
+                                else eval.category,
+                                reasoning=eval.reasoning
+                                if eval.reasoning
+                                else "No reasoning provided",
+                                score=score,
+                            )
+                        )
                     except (ValueError, TypeError, AttributeError):
-                        validated_evals.append(CategoryEvaluation(
-                            category=categories[i] if i < len(categories) else "Unknown",
-                            reasoning=ERROR_VALIDATION,
-                            score=DEFAULT_SCORE
-                        ))
+                        validated_evals.append(
+                            CategoryEvaluation(
+                                category=categories[i]
+                                if i < len(categories)
+                                else "Unknown",
+                                reasoning=ERROR_VALIDATION,
+                                score=DEFAULT_SCORE,
+                            )
+                        )
                 evaluations = validated_evals
 
             # Create validated result
             validated_result = EvaluationResult(evaluations=evaluations)
 
             return validated_result
         except Exception as e:
             # Return default evaluations on error
@@ -122,7 +160,8 @@ class TweetEvaluatorModule(dspy.Module):
                 CategoryEvaluation(
                     category=cat,
                     reasoning=f"{ERROR_EVALUATION}: {str(e)}",
-                    score=DEFAULT_SCORE
-                ) for cat in categories
+                    score=DEFAULT_SCORE,
+                )
+                for cat in categories
             ]
             return EvaluationResult(evaluations=default_evals)
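
Because the imports are now relative (from .models, from .constants, from .helpers instead of their top-level forms), agent/ has to be imported as a package rather than run as a folder of standalone scripts. The sketch below is hypothetical and only illustrates that implication: it assumes agent/ ships an __init__.py, and the DSPy language-model setup (including the model name) is an example, not something this commit configures.

# Hypothetical caller, not part of this commit.
import dspy

# Works only when agent/ is importable as a package (i.e. it has an __init__.py)
# and the caller sits outside the package directory.
from agent.modules import TweetGeneratorModule, TweetEvaluatorModule

# ChainOfThought modules need a configured LM before they can run;
# the model identifier here is only an example.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

generator = TweetGeneratorModule()
evaluator = TweetEvaluatorModule()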
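The two modules are designed to feed each other: TweetEvaluatorModule.forward returns an EvaluationResult, and TweetGeneratorModule.forward accepts it back as previous_evaluation. Continuing the sketch above, a minimal, hypothetical driver loop over the signatures in this diff (the input text, category list, and iteration count are illustrative, and a real loop would keep whichever candidate scores best):

# Hypothetical optimization loop; not part of this commit.
original = "Our new caching layer roughly halves cold-start latency."
categories = ["clarity", "engagement", "stays on message"]

previous_best = ""
last_eval = None  # EvaluationResult from the previous round, if any
for _ in range(3):
    # Generate (or improve) a tweet; forward() truncates it to TWEET_MAX_LENGTH.
    candidate = generator(
        input_text=original,
        current_tweet=previous_best,
        previous_evaluation=last_eval,
    )
    # Score the candidate per category; failures fall back to DEFAULT_SCORE entries.
    last_eval = evaluator(
        tweet_text=candidate,
        categories=categories,
        original_text=original,
        current_best_tweet=previous_best,
    )
    previous_best = candidate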