98 lines
3.2 KiB
Python
98 lines
3.2 KiB
Python
from modaic import PrecompiledAgent, PrecompiledConfig
|
|
from agent import SignatureGenerator
|
|
import dspy
|
|
|
|
|
|
class PromptToSignatureConfig(PrecompiledConfig):
    """Configuration for ``PromptToSignatureAgent``.

    The agent's ``__init__`` reads these fields to build its two ``dspy.LM``
    instances and its ``dspy.Refine`` wrapper.
    """

    # Model id for the dspy.LM built for the signature generator.
    # NOTE(review): a dated "preview" model id; confirm the provider still
    # serves it.
    lm: str = "gemini/gemini-2.5-pro-preview-03-25"

    # Model id for the dspy.LM applied to the dspy.Refine wrapper.
    refine_lm: str = "gemini/gemini-2.5-pro-preview-03-25"

    # Passed as max_tokens to both LMs.
    max_tokens: int = 4096

    # Passed as temperature to both LMs.
    temperature: float = 0.7

    # Passed as N to dspy.Refine (per the dspy docs, how many times the
    # wrapped module may be run).
    max_attempts_to_refine: int = 5
|
|
|
|
|
|
class PromptToSignatureAgent(PrecompiledAgent):
    """Generate a DSPy signature specification from a natural-language prompt.

    The specification exposes a name, a task description and a list of
    input/output fields. The agent wraps a ``SignatureGenerator`` and also
    builds a ``dspy.Refine`` around it. The refiner's reward function prints
    each candidate for review.
    """

    config: PromptToSignatureConfig

    def __init__(self, config: PromptToSignatureConfig, **kwargs):
        """Build the generator, its refiner and their language models.

        Args:
            config: Model ids and sampling settings for both LMs.
            **kwargs: Forwarded unchanged to ``PrecompiledAgent.__init__``.
        """
        super().__init__(config, **kwargs)

        self.signature_generator = SignatureGenerator()
        # validate_signature_with_feedback scores an approved candidate 1.0,
        # so threshold=1.0 lets Refine stop at the first approval.
        self.signature_refiner = dspy.Refine(
            module=self.signature_generator,
            N=config.max_attempts_to_refine,
            reward_fn=self.validate_signature_with_feedback,
            threshold=1.0,
        )

        lm = dspy.LM(
            model=config.lm,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
        )
        refine_lm = dspy.LM(
            model=config.refine_lm,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
        )

        self.signature_generator.set_lm(lm)
        # NOTE(review): the refiner wraps the *same* SignatureGenerator
        # instance. dspy's Module.set_lm assigns the LM to every predictor
        # reachable from the module, so if that holds here this call replaces
        # the `lm` set just above. forward() would then run on `refine_lm`,
        # and `config.lm` would have no effect whenever the two ids differ.
        # Confirm. A real fix likely needs a separate generator instance for
        # the refiner, which would change the saved-state layout.
        self.signature_refiner.set_lm(refine_lm)

    def forward(self, prompt: str, as_dict: bool = False) -> "dspy.Prediction | dict":
        """Generate a signature for ``prompt``.

        Args:
            prompt: Natural-language description of the task.
            as_dict: When true, return
                ``SignatureGenerator.generate_signature(prompt)`` (presumably
                a plain dict; verify against ``SignatureGenerator``) instead
                of calling the generator module.

        Returns:
            The ``dspy.Prediction`` produced by calling the generator, or the
            result of ``generate_signature`` when ``as_dict`` is true.
        """
        if as_dict:
            return self.signature_generator.generate_signature(prompt)
        return self.signature_generator(prompt)

    def generate_code(self, prediction: dspy.Prediction) -> str:
        """Return the generator's Python-code rendering of ``prediction``."""
        return self.signature_generator.generate_code(prediction)

    def validate_signature_with_feedback(
        self, args, pred, feedback=None, satisfied_with_score=True
    ):
        """Reward function for ``dspy.Refine`` that prints a candidate for review.

        Nothing is read interactively. The verdict and any feedback come from
        the ``satisfied_with_score`` and ``feedback`` arguments.

        Args:
            args: The refined module's call inputs (dspy.Refine passes them
                first); not used here.
            pred: Candidate prediction exposing ``signature_name``,
                ``task_description`` and ``signature_fields``.
            feedback: Reviewer comment; required when rejecting.
            satisfied_with_score: ``True`` approves the candidate.

        Returns:
            ``1.0`` when approved. Otherwise a ``dspy.Prediction`` with
            ``score=0.0`` and the given ``feedback``.

        Raises:
            ValueError: If the candidate is rejected without ``feedback``.
        """
        rule = "=" * 60
        print("\n" + rule)
        print("🔍 Review Generated Signature")
        print(rule)

        print(f"Signature Name: {pred.signature_name}")
        print(f"Description: {pred.task_description}")

        print(f"\nFields ({len(pred.signature_fields)}):")
        for i, field in enumerate(pred.signature_fields, 1):
            role_emoji = "📥" if field.role.value == "input" else "📤"
            print(
                f" {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}"
            )

        if satisfied_with_score:
            print("✓ Signature approved!")
            return 1.0

        # A rejection without feedback gives the refiner nothing to act on.
        if not feedback:
            raise ValueError("Feedback is required if you are not satisfied with the signature!")

        print(f"📝 Feedback recorded: {feedback}")
        return dspy.Prediction(score=0.0, feedback=feedback)
|
|
|
|
|
|
# Module-level agent with the default configuration. It is constructed at
# import time, which also builds its dspy.LM objects.
agent = PromptToSignatureAgent(config=PromptToSignatureConfig())
|
|
|
|
|
|
def main():
    """Upload the agent to the hub, then print code for a sample prompt.

    Note: the upload (with code) happens on every run, before generation.
    """
    agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)

    prediction = agent(prompt="Generate jokes by prompt")

    # Render the generated signature as Python source and show it.
    generated_code = agent.generate_code(prediction)
    print(generated_code)


if __name__ == "__main__":
    main()
|