(no commit message)

This commit is contained in:
2025-10-29 18:24:50 -04:00
parent 6475810170
commit e26020f40c
4 changed files with 91 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
from .modules import SignatureGenerator from .modules import SignatureGenerator
from .metrics import validate_signature_with_feedback
__all__ = ["SignatureGenerator"] __all__ = ["SignatureGenerator", "validate_signature_with_feedback"]

31
agent/metrics.py Normal file
View File

@@ -0,0 +1,31 @@
import dspy
def validate_signature_with_feedback(
args, pred, feedback=None, satisfied_with_score=True
):
"""Validation function for dspy.Refine that asks user for feedback"""
# Display the generated signature
print("\n" + "=" * 60)
print("🔍 Review Generated Signature")
print("=" * 60)
# Show the signature name and description
print(f"Signature Name: {pred.signature_name}")
print(f"Description: {pred.task_description}")
# Show the fields in a simple format
print(f"\nFields ({len(pred.signature_fields)}):")
for i, field in enumerate(pred.signature_fields, 1):
role_emoji = "📥" if field.role.value == "input" else "📤"
print(
f" {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}"
)
if satisfied_with_score:
print("✓ Signature approved!")
return 1.0
else:
print(f"📝 Feedback recorded: {feedback}")
return dspy.Prediction(score=0.0, feedback=feedback)

View File

@@ -1,5 +1,7 @@
{ {
"lm": "gemini/gemini-2.5-pro-preview-03-25", "lm": "gemini/gemini-2.5-pro-preview-03-25",
"refine_lm": "gemini/gemini-2.5-pro-preview-03-25",
"max_tokens": 4096, "max_tokens": 4096,
"temperature": 0.7 "temperature": 0.7,
"max_attempts_to_refine": 5
} }

57
main.py
View File

@@ -5,8 +5,10 @@ import dspy
class PromptToSignatureConfig(PrecompiledConfig): class PromptToSignatureConfig(PrecompiledConfig):
lm: str = "gemini/gemini-2.5-pro-preview-03-25" lm: str = "gemini/gemini-2.5-pro-preview-03-25"
refine_lm: str = "gemini/gemini-2.5-pro-preview-03-25"
max_tokens: int = 4096 max_tokens: int = 4096
temperature: float = 0.7 temperature: float = 0.7
max_attempts_to_refine: int = 5
class PromptToSignatureAgent(PrecompiledAgent): class PromptToSignatureAgent(PrecompiledAgent):
@@ -15,23 +17,69 @@ class PromptToSignatureAgent(PrecompiledAgent):
def __init__(self, config: PromptToSignatureConfig, **kwargs): def __init__(self, config: PromptToSignatureConfig, **kwargs):
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
self.signature_generator = SignatureGenerator() self.signature_generator = SignatureGenerator()
self.signature_refiner = dspy.Refine(
module=self.signature_generator,
N=config.max_attempts_to_refine,
reward_fn=self.validate_signature_with_feedback,
threshold=1.0,
)
lm = dspy.LM( lm = dspy.LM(
model=config.lm, model=config.lm,
max_tokens=config.max_tokens, max_tokens=config.max_tokens,
temperature=config.temperature, temperature=config.temperature,
) )
refine_lm = dspy.LM(
model=config.refine_lm,
max_tokens=config.max_tokens,
temperature=config.temperature,
)
self.signature_generator.set_lm(lm) self.signature_generator.set_lm(lm)
self.signature_refiner.set_lm(refine_lm)
def forward(self, prompt: str, as_dict: bool = False) -> dspy.Prediction: # returns dspy.Prediction object or dict
def forward(self, prompt: str, as_dict: bool = False) -> dspy.Prediction:
# returns dspy.Prediction object or dict
return ( return (
self.signature_generator.generate_signature(prompt) self.signature_generator.generate_signature(prompt)
if as_dict if as_dict
else self.signature_generator(prompt) else self.signature_generator(prompt)
) )
def generate_code(self, prediction: dspy.Prediction) -> str:
return self.signature_generator.generate_code(prediction)
def validate_signature_with_feedback(self, args, pred, feedback = None, satisfied_with_score = True):
"""Validation function for dspy.Refine that asks user for feedback"""
# Display the generated signature
print("\n" + "=" * 60)
print("🔍 Review Generated Signature")
print("=" * 60)
# Show the signature name and description
print(f"Signature Name: {pred.signature_name}")
print(f"Description: {pred.task_description}")
# Show the fields in a simple format
print(f"\nFields ({len(pred.signature_fields)}):")
for i, field in enumerate(pred.signature_fields, 1):
role_emoji = "📥" if field.role.value == "input" else "📤"
print(
f" {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}"
)
if satisfied_with_score:
print("✓ Signature approved!")
return 1.0
else:
if not feedback:
raise ValueError("Feedback is required if you are not satisfied with the signature!")
print(f"📝 Feedback recorded: {feedback}")
return dspy.Prediction(score=0.0, feedback=feedback)
agent = PromptToSignatureAgent(PromptToSignatureConfig()) agent = PromptToSignatureAgent(PromptToSignatureConfig())
@@ -39,6 +87,11 @@ agent = PromptToSignatureAgent(PromptToSignatureConfig())
def main(): def main():
agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True) agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)
result = agent(prompt="Generate jokes by prompt")
# Print the generated Python code
print(agent.generate_code(result))
if __name__ == "__main__": if __name__ == "__main__":
main() main()