From e26020f40c51c98a28cc81ced60f0e252642a1ce Mon Sep 17 00:00:00 2001 From: Farouk Adeleke Date: Wed, 29 Oct 2025 18:24:50 -0400 Subject: [PATCH] (no commit message) --- agent/__init__.py | 3 ++- agent/metrics.py | 31 ++++++++++++++++++++++++++ config.json | 4 +++- main.py | 57 +++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 91 insertions(+), 4 deletions(-) create mode 100644 agent/metrics.py diff --git a/agent/__init__.py b/agent/__init__.py index 98962fa..5ebe1aa 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -1,3 +1,4 @@ from .modules import SignatureGenerator +from .metrics import validate_signature_with_feedback -__all__ = ["SignatureGenerator"] +__all__ = ["SignatureGenerator", "validate_signature_with_feedback"] diff --git a/agent/metrics.py b/agent/metrics.py new file mode 100644 index 0000000..3682fea --- /dev/null +++ b/agent/metrics.py @@ -0,0 +1,31 @@ +import dspy + + +def validate_signature_with_feedback( + args, pred, feedback=None, satisfied_with_score=True +): + """Validation function for dspy.Refine that asks user for feedback""" + + # Display the generated signature + print("\n" + "=" * 60) + print("🔍 Review Generated Signature") + print("=" * 60) + + # Show the signature name and description + print(f"Signature Name: {pred.signature_name}") + print(f"Description: {pred.task_description}") + + # Show the fields in a simple format + print(f"\nFields ({len(pred.signature_fields)}):") + for i, field in enumerate(pred.signature_fields, 1): + role_emoji = "📥" if field.role.value == "input" else "📤" + print( + f" {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}" + ) + + if satisfied_with_score: + print("✓ Signature approved!") + return 1.0 + else: + print(f"📝 Feedback recorded: {feedback}") + return dspy.Prediction(score=0.0, feedback=feedback) diff --git a/config.json b/config.json index 5fe6a03..8dcb8b2 100644 --- a/config.json +++ b/config.json @@ -1,5 +1,7 @@ { "lm": "gemini/gemini-2.5-pro-preview-03-25", + "refine_lm": "gemini/gemini-2.5-pro-preview-03-25", "max_tokens": 4096, - "temperature": 0.7 + "temperature": 0.7, + "max_attempts_to_refine": 5 } \ No newline at end of file diff --git a/main.py b/main.py index 57b6bb1..b231773 100644 --- a/main.py +++ b/main.py @@ -5,8 +5,10 @@ import dspy class PromptToSignatureConfig(PrecompiledConfig): lm: str = "gemini/gemini-2.5-pro-preview-03-25" + refine_lm: str = "gemini/gemini-2.5-pro-preview-03-25" max_tokens: int = 4096 temperature: float = 0.7 + max_attempts_to_refine: int = 5 class PromptToSignatureAgent(PrecompiledAgent): @@ -15,23 +17,69 @@ class PromptToSignatureAgent(PrecompiledAgent): def __init__(self, config: PromptToSignatureConfig, **kwargs): super().__init__(config, **kwargs) self.signature_generator = SignatureGenerator() + self.signature_refiner = dspy.Refine( + module=self.signature_generator, + N=config.max_attempts_to_refine, + reward_fn=self.validate_signature_with_feedback, + threshold=1.0, + ) + lm = dspy.LM( model=config.lm, max_tokens=config.max_tokens, temperature=config.temperature, ) + refine_lm = dspy.LM( + model=config.refine_lm, + max_tokens=config.max_tokens, + temperature=config.temperature, + ) self.signature_generator.set_lm(lm) + self.signature_refiner.set_lm(refine_lm) - def forward(self, prompt: str, as_dict: bool = False) -> dspy.Prediction: - # returns dspy.Prediction object or dict + def forward(self, prompt: str, as_dict: bool = False) -> dspy.Prediction: # returns dspy.Prediction object or dict + return ( self.signature_generator.generate_signature(prompt) if as_dict else self.signature_generator(prompt) ) + def generate_code(self, prediction: dspy.Prediction) -> str: + return self.signature_generator.generate_code(prediction) + + def validate_signature_with_feedback(self, args, pred, feedback = None, satisfied_with_score = True): + """Validation function for dspy.Refine that asks user for feedback""" + + # Display the generated signature + print("\n" + "=" * 60) + print("🔍 Review Generated Signature") + print("=" * 60) + + # Show the signature name and description + print(f"Signature Name: {pred.signature_name}") + print(f"Description: {pred.task_description}") + + # Show the fields in a simple format + print(f"\nFields ({len(pred.signature_fields)}):") + for i, field in enumerate(pred.signature_fields, 1): + role_emoji = "📥" if field.role.value == "input" else "📤" + print( + f" {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}" + ) + + if satisfied_with_score: + print("✓ Signature approved!") + return 1.0 + else: + if not feedback: + raise ValueError("Feedback is required if you are not satisfied with the signature!") + + print(f"📝 Feedback recorded: {feedback}") + return dspy.Prediction(score=0.0, feedback=feedback) + agent = PromptToSignatureAgent(PromptToSignatureConfig()) @@ -39,6 +87,11 @@ agent = PromptToSignatureAgent(PromptToSignatureConfig()) def main(): agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True) + result = agent(prompt="Generate jokes by prompt") + + # Print the generated Python code + print(agent.generate_code(result)) + if __name__ == "__main__": main()