(no commit message)
This commit is contained in:
4
auto_classes.json
Normal file
4
auto_classes.json
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"AutoConfig": "main.CodeGeneratorConfig",
|
||||||
|
"AutoProgram": "main.CodeGenerator"
|
||||||
|
}
|
||||||
37
main.py
Normal file
37
main.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import dspy
|
||||||
|
from modaic import PrecompiledProgram, PrecompiledConfig
|
||||||
|
|
||||||
|
class CodeGeneratorConfig(PrecompiledConfig):
|
||||||
|
model : str = "openai//models/checkpoint-5"
|
||||||
|
api_base : str = "https://modaic-ai--grpo-demo-serve.modal.run/v1"
|
||||||
|
max_tokens : int = 1024
|
||||||
|
temperature : float = 0.7
|
||||||
|
|
||||||
|
|
||||||
|
class CodeGeneration(dspy.Signature):
|
||||||
|
query: str = dspy.InputField(desc="The query to generate code for.")
|
||||||
|
code: str = dspy.OutputField(desc="The code to generate as a python function.")
|
||||||
|
|
||||||
|
class CodeGenerator(PrecompiledProgram):
|
||||||
|
config: CodeGeneratorConfig
|
||||||
|
|
||||||
|
def __init__(self, config: CodeGeneratorConfig, **kwargs):
|
||||||
|
super().__init__(config=config, **kwargs)
|
||||||
|
|
||||||
|
modal_lm = dspy.LM(
|
||||||
|
model=config.model,
|
||||||
|
api_base=config.api_base,
|
||||||
|
max_tokens=config.max_tokens,
|
||||||
|
temperature=config.temperature,
|
||||||
|
)
|
||||||
|
self.answer_question = dspy.Predict(CodeGeneration)
|
||||||
|
self.answer_question.set_lm(modal_lm)
|
||||||
|
|
||||||
|
def forward(self, query):
|
||||||
|
return self.answer_question(query=query)
|
||||||
|
|
||||||
|
code_generator = CodeGenerator(CodeGeneratorConfig())
|
||||||
|
print(code_generator(query="Write a python function that returns the sum of two numbers.").code)
|
||||||
|
code_generator.push_to_hub("modaic/code-generator-trl-grpo", with_code=True, tag="v1")
|
||||||
|
|
||||||
|
|
||||||
7
pyproject.toml
Normal file
7
pyproject.toml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
[project]
|
||||||
|
name = "code-generator-trl-grpo"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.13"
|
||||||
|
dependencies = ["dotenv>=0.9.9", "dspy>=3.0.4", "modaic>=0.8.2", "modal>=1.3.0.post1"]
|
||||||
Reference in New Issue
Block a user