(no commit message)
This commit is contained in:
29
main.py
29
main.py
@@ -1,23 +1,44 @@
|
||||
from modaic import PrecompiledAgent, PrecompiledConfig
|
||||
from agent import SignatureGenerator
|
||||
import dspy
|
||||
|
||||
|
||||
class PromptToSignatureConfig(PrecompiledConfig):
|
||||
lm: str = "gpt-4o"
|
||||
max_tokens: int = 1024
|
||||
lm: str = "gemini/gemini-2.5-pro-preview-03-25"
|
||||
max_tokens: int = 4096
|
||||
temperature: float = 0.7
|
||||
|
||||
|
||||
class PromptToSignatureAgent(PrecompiledAgent):
|
||||
config: PromptToSignatureConfig
|
||||
|
||||
def __init__(self, config: PromptToSignatureConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.signature_generator = SignatureGenerator()
|
||||
|
||||
def forward(self, prompt: str) -> str:
|
||||
return "hello world"
|
||||
lm = dspy.LM(
|
||||
model=config.lm,
|
||||
max_tokens=config.max_tokens,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
|
||||
self.signature_generator.set_lm(lm)
|
||||
|
||||
def forward(self, prompt: str, as_dict: bool = False) -> dspy.Prediction:
|
||||
# returns dspy.Prediction object or dict
|
||||
return (
|
||||
self.signature_generator.generate_signature(prompt)
|
||||
if as_dict
|
||||
else self.signature_generator(prompt)
|
||||
)
|
||||
|
||||
|
||||
agent = PromptToSignatureAgent(PromptToSignatureConfig())
|
||||
|
||||
|
||||
def main():
|
||||
agent.push_to_hub("fadeleke/prompt-to-signature", with_code=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user