from typing import Optional

import dspy
from modaic import PrecompiledAgent, PrecompiledConfig

from .modules import signature_generator


class PromptToSignatureConfig(PrecompiledConfig):
    """Configuration for :class:`PromptToSignatureAgent`."""

    # Model used for plain (non-refined) signature generation.
    lm: str = "gemini/gemini-2.5-pro-preview-03-25"
    # Model used when ``forward(..., refine=True)`` runs the refinement loop.
    refine_lm: str = "gemini/gemini-2.5-pro-preview-03-25"
    # Generation settings applied to both language models.
    max_tokens: int = 4096
    temperature: float = 0.7
    # Passed to ``dspy.Refine`` as ``N`` (maximum number of attempts).
    max_attempts_to_refine: int = 5


class PromptToSignatureAgent(PrecompiledAgent):
    """Agent that turns a natural-language prompt into a DSPy signature.

    Plain generation uses ``config.lm``. The optional human-in-the-loop
    refinement (``forward(..., refine=True)``) uses ``config.refine_lm`` and
    asks the user on stdin to approve each candidate signature.
    """

    config: PromptToSignatureConfig

    def __init__(self, config: PromptToSignatureConfig, **kwargs):
        super().__init__(config, **kwargs)
        # Work on a private copy so that attaching LMs below does not mutate the
        # module-level ``signature_generator`` shared by every agent instance.
        self.signature_generator = signature_generator.deepcopy()
        # The refiner wraps this same generator instance, so both share one set
        # of predictors (and therefore one set of saved parameters).
        self.signature_refiner = dspy.Refine(
            module=self.signature_generator,
            N=config.max_attempts_to_refine,
            reward_fn=PromptToSignatureAgent.validate_signature_with_feedback,
            threshold=1.0,
        )
        self._lm = dspy.LM(
            model=config.lm,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
        )
        self._refine_lm = dspy.LM(
            model=config.refine_lm,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
        )
        # Because the predictors are shared, calling ``set_lm`` on the refiner
        # would also overwrite the generator's LM (making ``config.lm`` dead).
        # Attach the generation LM here; ``forward`` swaps in the refinement
        # LM only for the duration of a refinement run.
        self.signature_generator.set_lm(self._lm)

    def forward(self, prompt: str, refine: bool = False) -> Optional[dspy.Prediction]:
        """Generate a signature for ``prompt``.

        Args:
            prompt: Natural-language description of the task. Must be non-empty.
            refine: When True, run the interactive ``dspy.Refine`` loop using
                ``config.refine_lm``. Otherwise run a single generation using
                ``config.lm``.

        Returns:
            The generated prediction, or ``None`` if refinement raised an
            exception (the error is printed rather than propagated).

        Raises:
            ValueError: if ``prompt`` is empty.
        """
        if not prompt:
            raise ValueError("Prompt is required!!")

        if not refine:
            return self.signature_generator(prompt)

        # The refiner runs the shared generator predictors, so point them at the
        # refinement LM for this call only and always restore the generation LM.
        self.signature_generator.set_lm(self._refine_lm)
        try:
            return self.signature_refiner(prompt=prompt)
        except Exception as e:
            print(f"Refinement failed: {e}")
            print("💡 Try adjusting your prompt or increasing max attempts")
            return None
        finally:
            self.signature_generator.set_lm(self._lm)

    def generate_code(self, prediction: dspy.Prediction) -> str:
        """Render ``prediction`` as source code via the signature generator."""
        return self.signature_generator.generate_code(prediction)

    @staticmethod  # attached to dspy.Refine as its reward function
    def validate_signature_with_feedback(args, pred):
        """Reward function for ``dspy.Refine`` that asks the user for approval.

        Prints the candidate signature, then prompts on stdin.

        Args:
            args: Inputs of the refined call as passed by ``dspy.Refine``
                (not used here).
            pred: Candidate prediction exposing ``signature_name``,
                ``task_description`` and ``signature_fields``.

        Returns:
            ``1.0`` when the user approves, which meets the refiner's
            ``threshold=1.0``. Otherwise a ``dspy.Prediction`` with
            ``score=0.0`` and the user's ``feedback`` text.

        Raises:
            ValueError: if the user rejects the signature without giving
                feedback.
        """
        # display the generated signature
        print("\n" + "=" * 60)
        print("🔍 Review Generated Signature")
        print("=" * 60)

        # show the signature name and description
        print(f"Signature Name: {pred.signature_name}")
        print(f"Description: {pred.task_description}")

        # show the fields in a simple format
        print(f"\nFields ({len(pred.signature_fields)}):")
        for i, field in enumerate(pred.signature_fields, 1):
            role_emoji = "📥" if field.role.value == "input" else "📤"
            print(
                f"  {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}"
            )

        # ask for user approval (in an app, this would be a state variable);
        # tolerate surrounding whitespace and accept "yes" as well as "y"
        answer = input("Are you satisfied with this signature? (y/n): ")
        if answer.strip().lower() in ("y", "yes"):
            print("✓ Signature approved!")
            return 1.0

        # ask for feedback (in an app, this would be a state variable);
        # whitespace-only input counts as no feedback
        feedback = input("Please provide feedback for improvement: ").strip()
        if not feedback:
            raise ValueError(
                "Feedback is required if you are not satisfied with the signature!"
            )

        print(f"📝 Feedback recorded: {feedback}")
        return dspy.Prediction(score=0.0, feedback=feedback)