Fix config override bug by recreating LMs after load_state
This commit is contained in:
54
nanocode.py
54
nanocode.py
@@ -23,7 +23,6 @@ MAGENTA = "\033[35m"
|
|||||||
|
|
||||||
# --- File operations ---
|
# --- File operations ---
|
||||||
|
|
||||||
|
|
||||||
def read_file(path: str, offset: int = 0, limit: int = None) -> str:
|
def read_file(path: str, offset: int = 0, limit: int = None) -> str:
|
||||||
"""Read file contents with line numbers.
|
"""Read file contents with line numbers.
|
||||||
|
|
||||||
@@ -163,6 +162,9 @@ def run_bash(cmd: str) -> str:
|
|||||||
output_lines.append("\n(timed out after 30s)")
|
output_lines.append("\n(timed out after 30s)")
|
||||||
return "".join(output_lines).strip() or "(empty output)"
|
return "".join(output_lines).strip() or "(empty output)"
|
||||||
|
|
||||||
|
|
||||||
|
# -- Program ---
|
||||||
|
|
||||||
class CodingAssistant(dspy.Signature):
|
class CodingAssistant(dspy.Signature):
|
||||||
"""You are a concise coding assistant with access to sub agents."""
|
"""You are a concise coding assistant with access to sub agents."""
|
||||||
|
|
||||||
@@ -171,6 +173,7 @@ class CodingAssistant(dspy.Signature):
|
|||||||
desc="Your response to the user after completing the task"
|
desc="Your response to the user after completing the task"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RLMCodingConfig(PrecompiledConfig):
|
class RLMCodingConfig(PrecompiledConfig):
|
||||||
max_iters: int = 50
|
max_iters: int = 50
|
||||||
lm: str = "openrouter/openai/gpt-5.2-codex"
|
lm: str = "openrouter/openai/gpt-5.2-codex"
|
||||||
@@ -210,7 +213,7 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
max_tokens=self.config.max_tokens,
|
max_tokens=self.config.max_tokens,
|
||||||
track_usage=self.config.track_usage,
|
track_usage=self.config.track_usage,
|
||||||
)
|
)
|
||||||
agent = dspy.RLM(
|
self.agent = dspy.RLM(
|
||||||
CodingAssistant,
|
CodingAssistant,
|
||||||
sub_lm=self.sub_lm,
|
sub_lm=self.sub_lm,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
@@ -218,12 +221,13 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
max_iterations=self.config.max_iters,
|
max_iterations=self.config.max_iters,
|
||||||
verbose=self.config.verbose,
|
verbose=self.config.verbose,
|
||||||
)
|
)
|
||||||
|
self.agent.set_lm(self.lm)
|
||||||
agent.set_lm(self.lm)
|
self.set_lm(self.lm)
|
||||||
self.agent = agent
|
|
||||||
|
|
||||||
def forward(self, task: str) -> str:
|
def forward(self, task: str) -> str:
|
||||||
assert task, "Task cannot be empty"
|
if not task:
|
||||||
|
return dspy.Prediction(answer="No Task Given.")
|
||||||
|
|
||||||
return self.agent(task=task)
|
return self.agent(task=task)
|
||||||
|
|
||||||
def get_tools(self):
|
def get_tools(self):
|
||||||
@@ -231,14 +235,14 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
|
|
||||||
def set_tool(self, name: str, tool: callable):
|
def set_tool(self, name: str, tool: callable):
|
||||||
self.tools[name] = tool
|
self.tools[name] = tool
|
||||||
self.reload_repl_tools()
|
self.reload_repl()
|
||||||
|
|
||||||
def remove_tool(self, name: str):
|
def remove_tool(self, name: str):
|
||||||
if name in self.tools:
|
if name in self.tools:
|
||||||
del self.tools[name]
|
del self.tools[name]
|
||||||
self.reload_repl_tools()
|
self.reload_repl()
|
||||||
|
|
||||||
def reload_repl_tools(
|
def reload_repl(
|
||||||
self,
|
self,
|
||||||
): # we need to create a new instance for tool mutations to be passed back into the REPL
|
): # we need to create a new instance for tool mutations to be passed back into the REPL
|
||||||
new_instance = dspy.RLM(
|
new_instance = dspy.RLM(
|
||||||
@@ -252,6 +256,36 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
new_instance.set_lm(self.lm)
|
new_instance.set_lm(self.lm)
|
||||||
self.agent = new_instance
|
self.agent = new_instance
|
||||||
|
|
||||||
|
def reload_lms(self):
|
||||||
|
"""Recreate LM objects from current config. Call this after changing config.lm or config.sub_lm."""
|
||||||
|
self.lm = dspy.LM(
|
||||||
|
model=self.config.lm,
|
||||||
|
api_base=self.config.api_base,
|
||||||
|
max_tokens=self.config.max_tokens,
|
||||||
|
track_usage=self.config.track_usage,
|
||||||
|
)
|
||||||
|
self.sub_lm = dspy.LM(
|
||||||
|
model=self.config.sub_lm,
|
||||||
|
api_base=self.config.api_base,
|
||||||
|
max_tokens=self.config.max_tokens,
|
||||||
|
track_usage=self.config.track_usage,
|
||||||
|
)
|
||||||
|
self.reload_repl()
|
||||||
|
if os.getenv("MODAIC_ENV") == "dev":
|
||||||
|
print(f"{BLUE}LMs RELOADED: {self.lm.model}, {self.sub_lm.model}{RESET}")
|
||||||
|
|
||||||
|
def load_state(self, state):
|
||||||
|
"""Override to recreate LMs from config after loading state.
|
||||||
|
|
||||||
|
PrecompiledProgram.from_precompiled() calls load_state() AFTER __init__,
|
||||||
|
which overwrites our LMs with saved state. We fix this by recreating
|
||||||
|
the LMs from self.config after the parent load_state runs. Modaic will
|
||||||
|
fix this in a later patch for future devs.
|
||||||
|
"""
|
||||||
|
super().load_state(state)
|
||||||
|
# recreate LMs from config (not from saved state)
|
||||||
|
self.reload_lms()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
agent = RLMCodingProgram(RLMCodingConfig())
|
agent = RLMCodingProgram(RLMCodingConfig())
|
||||||
agent.push_to_hub(MODAIC_REPO_PATH, commit_message="change signature")
|
agent.push_to_hub(MODAIC_REPO_PATH, commit_message="Fix config override bug by recreating LMs after load_state")
|
||||||
|
|||||||
Reference in New Issue
Block a user