(no commit message)
This commit is contained in:
4
auto_classes.json
Normal file
4
auto_classes.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"AutoConfig": "main.CodeGeneratorConfig",
|
||||
"AutoProgram": "main.CodeGenerator"
|
||||
}
|
||||
37
main.py
Normal file
37
main.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import dspy
|
||||
from modaic import PrecompiledProgram, PrecompiledConfig
|
||||
|
||||
class CodeGeneratorConfig(PrecompiledConfig):
|
||||
model : str = "openai//models/checkpoint-5"
|
||||
api_base : str = "https://modaic-ai--grpo-demo-serve.modal.run/v1"
|
||||
max_tokens : int = 1024
|
||||
temperature : float = 0.7
|
||||
|
||||
|
||||
class CodeGeneration(dspy.Signature):
|
||||
query: str = dspy.InputField(desc="The query to generate code for.")
|
||||
code: str = dspy.OutputField(desc="The code to generate as a python function.")
|
||||
|
||||
class CodeGenerator(PrecompiledProgram):
|
||||
config: CodeGeneratorConfig
|
||||
|
||||
def __init__(self, config: CodeGeneratorConfig, **kwargs):
|
||||
super().__init__(config=config, **kwargs)
|
||||
|
||||
modal_lm = dspy.LM(
|
||||
model=config.model,
|
||||
api_base=config.api_base,
|
||||
max_tokens=config.max_tokens,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
self.answer_question = dspy.Predict(CodeGeneration)
|
||||
self.answer_question.set_lm(modal_lm)
|
||||
|
||||
def forward(self, query):
|
||||
return self.answer_question(query=query)
|
||||
|
||||
code_generator = CodeGenerator(CodeGeneratorConfig())
|
||||
print(code_generator(query="Write a python function that returns the sum of two numbers.").code)
|
||||
code_generator.push_to_hub("modaic/code-generator-trl-grpo", with_code=True, tag="v1")
|
||||
|
||||
|
||||
7
pyproject.toml
Normal file
7
pyproject.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[project]
|
||||
name = "code-generator-trl-grpo"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = ["dotenv>=0.9.9", "dspy>=3.0.4", "modaic>=0.8.2", "modal>=1.3.0.post1"]
|
||||
Reference in New Issue
Block a user