(no commit message)
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
from .modules import signature_generator
|
from .modules import signature_generator
|
||||||
|
from .index import PromptToSignatureAgent, PromptToSignatureConfig
|
||||||
|
|
||||||
__all__ = ["signature_generator"]
|
__all__ = ["signature_generator", "PromptToSignatureAgent", "PromptToSignatureConfig"]
|
||||||
|
|||||||
96
agent/index.py
Normal file
96
agent/index.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import dspy
|
||||||
|
from .modules import signature_generator
|
||||||
|
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||||
|
|
||||||
|
class PromptToSignatureConfig(PrecompiledConfig):
|
||||||
|
lm: str = "gemini/gemini-2.5-pro-preview-03-25"
|
||||||
|
refine_lm: str = "gemini/gemini-2.5-pro-preview-03-25"
|
||||||
|
max_tokens: int = 4096
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_attempts_to_refine: int = 5
|
||||||
|
|
||||||
|
|
||||||
|
class PromptToSignatureAgent(PrecompiledAgent):
|
||||||
|
config: PromptToSignatureConfig
|
||||||
|
|
||||||
|
def __init__(self, config: PromptToSignatureConfig, **kwargs):
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
self.signature_generator = signature_generator
|
||||||
|
self.signature_refiner = dspy.Refine(
|
||||||
|
module=self.signature_generator,
|
||||||
|
N=config.max_attempts_to_refine,
|
||||||
|
reward_fn=PromptToSignatureAgent.validate_signature_with_feedback,
|
||||||
|
threshold=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
lm = dspy.LM(
|
||||||
|
model=config.lm,
|
||||||
|
max_tokens=config.max_tokens,
|
||||||
|
temperature=config.temperature,
|
||||||
|
)
|
||||||
|
refine_lm = dspy.LM(
|
||||||
|
model=config.refine_lm,
|
||||||
|
max_tokens=config.max_tokens,
|
||||||
|
temperature=config.temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.signature_generator.set_lm(lm)
|
||||||
|
self.signature_refiner.set_lm(refine_lm)
|
||||||
|
|
||||||
|
def forward(self, prompt: str, refine: bool = False) -> dspy.Prediction:
|
||||||
|
if not prompt:
|
||||||
|
raise ValueError("Prompt is required!!")
|
||||||
|
|
||||||
|
if refine:
|
||||||
|
try:
|
||||||
|
result = self.signature_refiner(prompt=prompt)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Refinement failed: {e}")
|
||||||
|
print("💡 Try adjusting your prompt or increasing max attempts")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
result = self.signature_generator(prompt)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def generate_code(self, prediction: dspy.Prediction) -> str:
|
||||||
|
return self.signature_generator.generate_code(prediction)
|
||||||
|
|
||||||
|
@staticmethod # attached metric for refinement
|
||||||
|
def validate_signature_with_feedback(args, pred):
|
||||||
|
"""Validation function for dspy.Refine that asks user for feedback"""
|
||||||
|
|
||||||
|
# display the generated signature
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("🔍 Review Generated Signature")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# show the signature name and description
|
||||||
|
print(f"Signature Name: {pred.signature_name}")
|
||||||
|
print(f"Description: {pred.task_description}")
|
||||||
|
|
||||||
|
# show the fields in a simple format
|
||||||
|
print(f"\nFields ({len(pred.signature_fields)}):")
|
||||||
|
for i, field in enumerate(pred.signature_fields, 1):
|
||||||
|
role_emoji = "📥" if field.role.value == "input" else "📤"
|
||||||
|
print(
|
||||||
|
f" {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ask for user approval (in an app, this would be a state variable)
|
||||||
|
is_satisfied = input("Are you satisfied with this signature? (y/n): ")
|
||||||
|
is_satisfied = is_satisfied.lower() == "y"
|
||||||
|
|
||||||
|
if is_satisfied:
|
||||||
|
print("✓ Signature approved!")
|
||||||
|
return 1.0
|
||||||
|
else:
|
||||||
|
# ask for feedback (in an app, this would be a state variable)
|
||||||
|
feedback = input("Please provide feedback for improvement: ")
|
||||||
|
if not feedback:
|
||||||
|
raise ValueError(
|
||||||
|
"Feedback is required if you are not satisfied with the signature!"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"📝 Feedback recorded: {feedback}")
|
||||||
|
return dspy.Prediction(score=0.0, feedback=feedback)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"AutoConfig": "main.PromptToSignatureConfig",
|
"AutoConfig": "agent.index.PromptToSignatureConfig",
|
||||||
"AutoAgent": "main.PromptToSignatureAgent"
|
"AutoAgent": "agent.index.PromptToSignatureAgent"
|
||||||
}
|
}
|
||||||
104
main.py
104
main.py
@@ -1,101 +1,4 @@
|
|||||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
from agent import PromptToSignatureAgent, PromptToSignatureConfig
|
||||||
from agent import signature_generator
|
|
||||||
import dspy
|
|
||||||
|
|
||||||
|
|
||||||
class PromptToSignatureConfig(PrecompiledConfig):
|
|
||||||
lm: str = "gemini/gemini-2.5-pro-preview-03-25"
|
|
||||||
refine_lm: str = "gemini/gemini-2.5-pro-preview-03-25"
|
|
||||||
max_tokens: int = 4096
|
|
||||||
temperature: float = 0.7
|
|
||||||
max_attempts_to_refine: int = 5
|
|
||||||
|
|
||||||
|
|
||||||
class PromptToSignatureAgent(PrecompiledAgent):
|
|
||||||
config: PromptToSignatureConfig
|
|
||||||
|
|
||||||
def __init__(self, config: PromptToSignatureConfig, **kwargs):
|
|
||||||
super().__init__(config, **kwargs)
|
|
||||||
self.signature_generator = signature_generator
|
|
||||||
self.signature_refiner = dspy.Refine(
|
|
||||||
module=self.signature_generator,
|
|
||||||
N=config.max_attempts_to_refine,
|
|
||||||
reward_fn=PromptToSignatureAgent.validate_signature_with_feedback,
|
|
||||||
threshold=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
lm = dspy.LM(
|
|
||||||
model=config.lm,
|
|
||||||
max_tokens=config.max_tokens,
|
|
||||||
temperature=config.temperature,
|
|
||||||
)
|
|
||||||
refine_lm = dspy.LM(
|
|
||||||
model=config.refine_lm,
|
|
||||||
max_tokens=config.max_tokens,
|
|
||||||
temperature=config.temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.signature_generator.set_lm(lm)
|
|
||||||
self.signature_refiner.set_lm(refine_lm)
|
|
||||||
|
|
||||||
def forward(self, prompt: str, refine: bool = False) -> dspy.Prediction:
|
|
||||||
if not prompt:
|
|
||||||
raise ValueError("Prompt is required!!")
|
|
||||||
|
|
||||||
if refine:
|
|
||||||
try:
|
|
||||||
result = self.signature_refiner(prompt=prompt)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Refinement failed: {e}")
|
|
||||||
print("💡 Try adjusting your prompt or increasing max attempts")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
result = self.signature_generator(prompt)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def generate_code(self, prediction: dspy.Prediction) -> str:
|
|
||||||
return self.signature_generator.generate_code(prediction)
|
|
||||||
|
|
||||||
@staticmethod # attached metric for refinement
|
|
||||||
def validate_signature_with_feedback(args, pred):
|
|
||||||
"""Validation function for dspy.Refine that asks user for feedback"""
|
|
||||||
|
|
||||||
# display the generated signature
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🔍 Review Generated Signature")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# show the signature name and description
|
|
||||||
print(f"Signature Name: {pred.signature_name}")
|
|
||||||
print(f"Description: {pred.task_description}")
|
|
||||||
|
|
||||||
# show the fields in a simple format
|
|
||||||
print(f"\nFields ({len(pred.signature_fields)}):")
|
|
||||||
for i, field in enumerate(pred.signature_fields, 1):
|
|
||||||
role_emoji = "📥" if field.role.value == "input" else "📤"
|
|
||||||
print(
|
|
||||||
f" {i}. {role_emoji} {field.name} ({field.type.value}) - {field.description}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ask for user approval (in an app, this would be a state variable)
|
|
||||||
is_satisfied = input("Are you satisfied with this signature? (y/n): ")
|
|
||||||
is_satisfied = is_satisfied.lower() == "y"
|
|
||||||
|
|
||||||
if is_satisfied:
|
|
||||||
print("✓ Signature approved!")
|
|
||||||
return 1.0
|
|
||||||
else:
|
|
||||||
# ask for feedback (in an app, this would be a state variable)
|
|
||||||
feedback = input("Please provide feedback for improvement: ")
|
|
||||||
if not feedback:
|
|
||||||
raise ValueError(
|
|
||||||
"Feedback is required if you are not satisfied with the signature!"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"📝 Feedback recorded: {feedback}")
|
|
||||||
return dspy.Prediction(score=0.0, feedback=feedback)
|
|
||||||
|
|
||||||
|
|
||||||
agent = PromptToSignatureAgent(PromptToSignatureConfig())
|
agent = PromptToSignatureAgent(PromptToSignatureConfig())
|
||||||
|
|
||||||
@@ -167,11 +70,6 @@ CR_PROMPT = """ You are Charlotte, an advanced knowledge graph connection reason
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# try refine
|
|
||||||
#refined_result = agent(
|
|
||||||
# prompt=CR_PROMPT,
|
|
||||||
#)
|
|
||||||
|
|
||||||
agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)
|
agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user