Fix config override bug by recreating LMs after load_state

This commit is contained in:
2026-01-24 14:45:21 -08:00
parent ad1d95950e
commit 68c4115454

View File

@@ -170,7 +170,8 @@ def run_bash(cmd: str) -> str:
class RLMReasoningCallback(BaseCallback):
    """Callback that echoes each reasoning/code step of the agent to the console."""

    def on_module_end(self, call_id, outputs, exception):
        # Nothing to log when the module produced no output.
        if not outputs:
            return
        # Only log outputs that carry both a reasoning trace and generated code.
        if hasattr(outputs, "reasoning") and hasattr(outputs, "code"):
            print(f"{DIM}[REASONING STEP] {outputs.reasoning}\n{RESET}")
            print(f"{DIM}⏺ [CODE] {outputs.code}\n{RESET}")
# -- Program --- # -- Program ---
@@ -230,25 +231,77 @@ class RLMCodingProgram(PrecompiledProgram):
tools=self.tools, tools=self.tools,
max_output_chars=self.config.max_output_chars, max_output_chars=self.config.max_output_chars,
max_iterations=self.config.max_iters, max_iterations=self.config.max_iters,
verbose=self.config.verbose,
) )
self.agent.set_lm(self.lm) self.agent.set_lm(self.lm)
self.set_lm(self.lm)
if self.config.verbose:
self.add_logging_callbacks()
def add_logging_callbacks(self):
"""Add logging callbacks to the agent."""
self.agent.generate_action.callbacks.append(RLMReasoningCallback())
self._patch_llm_tools()
def _patch_llm_tools(self):
"""Monkey-patch the RLM's _make_llm_tools to add structured verbose logging."""
orig_factory = (
self.agent._make_llm_tools
) # capture the original bound method directly
def verbose_factory(max_workers=8):
tools = orig_factory(
max_workers=max_workers
) # call the original bound method
orig_q = tools["llm_query"]
orig_b = tools["llm_query_batched"]
def wrapped_q(prompt): # wrap query
print(
f"{DIM}⏺ [LLM QUERY]: {prompt[:100]}...{RESET}\n"
if len(prompt) > 100
else f"{DIM}⏺ [LLM QUERY]: {prompt}{RESET}\n"
)
res = orig_q(prompt)
print(
f"{DIM}⏺ [LLM QUERY RESULT]: {str(res)[:200]}...{RESET}\n"
if len(str(res)) > 200
else f"{DIM}⏺ [LLM QUERY RESULT]: {res}{RESET}\n"
)
return res
def wrapped_b(prompts): # wrap batched query
print(f"{DIM}⏺ [LLM QUERY BATCHED]: {len(prompts)} prompts{RESET}\n")
res = orig_b(prompts)
print(f"{DIM}⏺ [LLM QUERY BATCHED]: {len(res)} results{RESET}\n")
return res
tools["llm_query"] = wrapped_q
tools["llm_query_batched"] = wrapped_b
return tools
self.agent._make_llm_tools = verbose_factory
def forward(self, task: str) -> str: def forward(self, task: str) -> str:
"""Forward pass for the agent."""
if not task: if not task:
return dspy.Prediction(answer="No Task Given.") return dspy.Prediction(answer="No Task Given.")
return self.agent(task=task) return self.agent(task=task)
def get_tools(self): def get_tools(self):
"""Get the tools for the agent."""
return self.tools return self.tools
def set_tool(self, name: str, tool: callable): def set_tool(self, name: str, tool: callable):
"""Set a tool for the agent."""
self.tools[name] = tool self.tools[name] = tool
self.reload_repl() self.reload_repl()
def remove_tool(self, name: str): def remove_tool(self, name: str):
"""Remove a tool from the agent."""
if name in self.tools: if name in self.tools:
del self.tools[name] del self.tools[name]
self.reload_repl() self.reload_repl()
@@ -256,6 +309,8 @@ class RLMCodingProgram(PrecompiledProgram):
def reload_repl( def reload_repl(
self, self,
): # we need to create a new instance for tool mutations to be passed back into the REPL ): # we need to create a new instance for tool mutations to be passed back into the REPL
"""Reload the REPL with the current tools."""
new_instance = dspy.RLM( new_instance = dspy.RLM(
CodingAssistant, CodingAssistant,
sub_lm=self.sub_lm, sub_lm=self.sub_lm,
@@ -266,9 +321,12 @@ class RLMCodingProgram(PrecompiledProgram):
) )
new_instance.set_lm(self.lm) new_instance.set_lm(self.lm)
self.agent = new_instance self.agent = new_instance
if self.config.verbose:
self.add_logging_callbacks()
def reload_lms(self): def reload_lms(self):
"""Recreate LM objects from current config. Call this after changing config.lm or config.sub_lm.""" """Recreate LM objects from current config. Call this after changing config.lm or config.sub_lm."""
self.lm = dspy.LM( self.lm = dspy.LM(
model=self.config.lm, model=self.config.lm,
api_base=self.config.api_base, api_base=self.config.api_base,
@@ -294,13 +352,19 @@ class RLMCodingProgram(PrecompiledProgram):
fix this in a later patch for future devs. fix this in a later patch for future devs.
""" """
super().load_state(state) super().load_state(state)
# recreate LMs from config (not from saved state) self.reload_lms() # recreate LMs from config (not from saved state)
self.reload_lms()
if __name__ == "__main__":
    agent = RLMCodingProgram(RLMCodingConfig())
    # Publish the updated program to every tracked branch so they stay in sync.
    # Hoisted so the same message is guaranteed on each branch.
    commit_message = "Fix config override bug by recreating LMs after load_state"
    for branch in ("main", "dev", "prod"):
        agent.push_to_hub(
            MODAIC_REPO_PATH,
            commit_message=commit_message,
            branch=branch,
        )