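"""TweetOptimizerAgent: a Modaic PrecompiledAgent that iteratively rewrites input
text into a tweet using a DSPy generator/evaluator pair driven by a
hill-climbing optimizer."""
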
from typing import List, Optional

from modaic import PrecompiledAgent, PrecompiledConfig

from .constants import DEFAULT_CATEGORIES, DEFAULT_ITERATIONS, DEFAULT_PATIENCE
from .hill_climbing import HillClimbingOptimizer
from .models import EvaluationResult
from .modules import TweetEvaluatorModule, TweetGeneratorModule
from .utils import get_dspy_lm


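# Config for the precompiled agent: generator and evaluator LMs, scoring
# categories, and the hill-climbing budget (max iterations and early-stop
# patience). Defaults come from .constants.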
class TweetOptimizerConfig(PrecompiledConfig):
    lm: str = "openrouter/google/gemini-2.5-flash"
    eval_lm: str = "openrouter/openai/gpt-5"
    categories: List[str] = DEFAULT_CATEGORIES
    max_iterations: int = DEFAULT_ITERATIONS
    patience: int = DEFAULT_PATIENCE


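# Agent that wires a DSPy tweet generator and evaluator into a hill-climbing
# optimizer; forward() runs the loop and returns a summary of the run.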
class TweetOptimizerAgent(PrecompiledAgent):
    config: TweetOptimizerConfig

    current_tweet: str = ""
    previous_evaluation: Optional[EvaluationResult] = None

    def __init__(self, config: TweetOptimizerConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.tweet_generator = TweetGeneratorModule()
        self.tweet_evaluator = TweetEvaluatorModule()

        # Set up the hill-climbing optimizer over the generator/evaluator pair.
        self.optimizer = HillClimbingOptimizer(
            generator=self.tweet_generator,
            evaluator=self.tweet_evaluator,
            categories=config.categories,
            max_iterations=config.max_iterations,
            patience=config.patience,
        )

        self.lm = config.lm
        self.eval_lm = config.eval_lm

        # Initialize the DSPy LMs: one model for generation, one for evaluation.
        self.tweet_generator.set_lm(get_dspy_lm(config.lm))
        self.tweet_evaluator.set_lm(get_dspy_lm(config.eval_lm))

    def forward(
        self,
        input_text: str,
        iterations: Optional[int] = None,
        patience: Optional[int] = None,
    ) -> dict:
        """Run the full optimization process and return a summary dict."""
        max_iterations = iterations or self.config.max_iterations
        patience_limit = patience or self.config.patience

        results = {
            'initial_text': input_text,
            'final_tweet': '',
            'best_score': 0.0,
            'iterations_run': 0,
            'early_stopped': False,
            'scores_history': [],
            'improvement_count': 0,
        }

        best_tweet = ""
        best_score = 0.0

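        # optimizer.optimize() yields (tweet, scores, is_improvement,
        # patience_counter, _, _) tuples; track the best candidate by its
        # mean category score.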
        for iteration, (current_tweet, scores, is_improvement, patience_counter, _, _) in enumerate(
            self.optimizer.optimize(input_text)
        ):
            iteration_num = iteration + 1
            results['iterations_run'] = iteration_num
            results['scores_history'].append(scores)

            if is_improvement:
                best_tweet = current_tweet
                best_score = sum(scores.category_scores) / len(scores.category_scores)
                results['improvement_count'] += 1

            # Check for early stopping.
            if patience_counter >= patience_limit:
                results['early_stopped'] = True
                break

            # Stop once the iteration budget is exhausted.
            if iteration_num >= max_iterations:
                break

        results.update({
            'final_tweet': best_tweet,
            'best_score': best_score,
        })

        self.reset()

        return results

    def reset(self):
        """Clear per-run state so the agent can be reused."""
        self.current_tweet = ""
        self.previous_evaluation = None


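# Smoke test: run a short optimization on a sample input and push the agent
# to the hub. Requires OPENROUTER_API_KEY to be set in the environment.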
if __name__ == "__main__":
    import os

    # An OpenRouter API key is required before any model calls are made.
    if not os.getenv("OPENROUTER_API_KEY"):
        raise ValueError("OPENROUTER_API_KEY environment variable is not set")

    # Create the agent with the default config.
    config = TweetOptimizerConfig()
    tweet_optimizer = TweetOptimizerAgent(config)

    # Full optimization process.
    print("\n=== Full Optimization Process ===")
    try:
        results = tweet_optimizer(
            input_text="Anthropic added a new OSS model on HuggingFace.",
            iterations=10,  # reduced for testing
            patience=8,
        )
        print(f"Initial text: {results['initial_text']}")
        print(f"Final tweet: {results['final_tweet']}")
        print(f"Best score: {results['best_score']:.2f}")
        print(f"Iterations run: {results['iterations_run']}")
        print(f"Improvements found: {results['improvement_count']}")
        print(f"Early stopped: {results['early_stopped']}")
    except Exception as e:
        print(f"Error in optimization: {e}")

    # Push to hub.
    print("\n=== Push to Hub ===")
    try:
        tweet_optimizer.push_to_hub(
            "farouk1/tweet-optimizer-v2",
            commit_message="Complete Migration",
            with_code=True,
        )
        print("Successfully pushed to hub!")
    except Exception as e:
        print(f"Error pushing to hub: {e}")

    print("\n=== Agent Configuration ===")
    print(f"Generator model: {config.lm}")
    print(f"Evaluator model: {config.eval_lm}")
    print(f"Categories: {config.categories}")
    print(f"Max iterations: {config.max_iterations}")
    print(f"Patience: {config.patience}")