(no commit message)
This commit is contained in:
37
main.py
Normal file
37
main.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import dspy
|
||||
from modaic import PrecompiledProgram, PrecompiledConfig
|
||||
|
||||
class CodeGeneratorConfig(PrecompiledConfig):
|
||||
model : str = "openai//models/checkpoint-5"
|
||||
api_base : str = "https://modaic-ai--grpo-demo-serve.modal.run/v1"
|
||||
max_tokens : int = 1024
|
||||
temperature : float = 0.7
|
||||
|
||||
|
||||
class CodeGeneration(dspy.Signature):
|
||||
query: str = dspy.InputField(desc="The query to generate code for.")
|
||||
code: str = dspy.OutputField(desc="The code to generate as a python function.")
|
||||
|
||||
class CodeGenerator(PrecompiledProgram):
|
||||
config: CodeGeneratorConfig
|
||||
|
||||
def __init__(self, config: CodeGeneratorConfig, **kwargs):
|
||||
super().__init__(config=config, **kwargs)
|
||||
|
||||
modal_lm = dspy.LM(
|
||||
model=config.model,
|
||||
api_base=config.api_base,
|
||||
max_tokens=config.max_tokens,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
self.answer_question = dspy.Predict(CodeGeneration)
|
||||
self.answer_question.set_lm(modal_lm)
|
||||
|
||||
def forward(self, query):
|
||||
return self.answer_question(query=query)
|
||||
|
||||
code_generator = CodeGenerator(CodeGeneratorConfig())
|
||||
print(code_generator(query="Write a python function that returns the sum of two numbers.").code)
|
||||
code_generator.push_to_hub("modaic/code-generator-trl-grpo", with_code=True, tag="v1")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user