(no commit message)
This commit is contained in:
18
main.py
18
main.py
@@ -24,7 +24,6 @@ class PromptToSignatureAgent(PrecompiledAgent):
|
||||
threshold=1.0,
|
||||
)
|
||||
|
||||
|
||||
lm = dspy.LM(
|
||||
model=config.lm,
|
||||
max_tokens=config.max_tokens,
|
||||
@@ -39,8 +38,9 @@ class PromptToSignatureAgent(PrecompiledAgent):
|
||||
self.signature_generator.set_lm(lm)
|
||||
self.signature_refiner.set_lm(refine_lm)
|
||||
|
||||
def forward(self, prompt: str, as_dict: bool = False) -> dspy.Prediction: # returns dspy.Prediction object or dict
|
||||
|
||||
def forward(
|
||||
self, prompt: str, as_dict: bool = False
|
||||
) -> dspy.Prediction: # returns dspy.Prediction object or dict
|
||||
return (
|
||||
self.signature_generator.generate_signature(prompt)
|
||||
if as_dict
|
||||
@@ -50,7 +50,9 @@ class PromptToSignatureAgent(PrecompiledAgent):
|
||||
def generate_code(self, prediction: dspy.Prediction) -> str:
|
||||
return self.signature_generator.generate_code(prediction)
|
||||
|
||||
def validate_signature_with_feedback(self, args, pred, feedback = None, satisfied_with_score = True):
|
||||
def validate_signature_with_feedback(
|
||||
self, args, pred, feedback=None, satisfied_with_score=True
|
||||
):
|
||||
"""Validation function for dspy.Refine that asks user for feedback"""
|
||||
|
||||
# Display the generated signature
|
||||
@@ -75,7 +77,9 @@ class PromptToSignatureAgent(PrecompiledAgent):
|
||||
return 1.0
|
||||
else:
|
||||
if not feedback:
|
||||
raise ValueError("Feedback is required if you are not satisfied with the signature!")
|
||||
raise ValueError(
|
||||
"Feedback is required if you are not satisfied with the signature!"
|
||||
)
|
||||
|
||||
print(f"📝 Feedback recorded: {feedback}")
|
||||
return dspy.Prediction(score=0.0, feedback=feedback)
|
||||
@@ -88,8 +92,8 @@ def main():
|
||||
agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)
|
||||
|
||||
result = agent(prompt="Generate jokes by prompt")
|
||||
|
||||
# Print the generated Python code
|
||||
print("RESULT: ", result)
|
||||
print("GENERATED CODE:\n\n")
|
||||
print(agent.generate_code(result))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user