Fix config override bug by recreating LMs after load_state
This commit is contained in:
75
nanocode.py
75
nanocode.py
@@ -170,7 +170,8 @@ def run_bash(cmd: str) -> str:
|
|||||||
class RLMReasoningCallback(BaseCallback):
    """Callback that prints the reasoning and code of each finished step."""

    def on_module_end(self, call_id, outputs, exception):
        """Print ``outputs.reasoning`` and ``outputs.code`` when both exist.

        Args:
            call_id: Identifier of the finished call (not used here).
            outputs: Result object of the call. Nothing is printed unless it
                is truthy and has both a ``reasoning`` and a ``code``
                attribute.
            exception: Exception reported for the call (not used here).
        """
        if not outputs:
            return
        if not (hasattr(outputs, "reasoning") and hasattr(outputs, "code")):
            return

        # Dimmed output keeps the trace visually separate from normal output.
        print(f"{DIM}⏺ [REASONING STEP] {outputs.reasoning}\n{RESET}")
        print(f"{DIM}⏺ [CODE] {outputs.code}\n{RESET}")
|
||||||
|
|
||||||
|
|
||||||
# -- Program ---
|
# -- Program ---
|
||||||
@@ -230,25 +231,77 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
max_output_chars=self.config.max_output_chars,
|
max_output_chars=self.config.max_output_chars,
|
||||||
max_iterations=self.config.max_iters,
|
max_iterations=self.config.max_iters,
|
||||||
verbose=self.config.verbose,
|
|
||||||
)
|
)
|
||||||
self.agent.set_lm(self.lm)
|
self.agent.set_lm(self.lm)
|
||||||
self.set_lm(self.lm)
|
|
||||||
|
if self.config.verbose:
|
||||||
|
self.add_logging_callbacks()
|
||||||
|
|
||||||
|
def add_logging_callbacks(self):
    """Add logging callbacks to the agent.

    Appends a fresh ``RLMReasoningCallback`` to the callback list of
    ``self.agent.generate_action`` and then wraps the agent's LLM tools
    with logging via ``self._patch_llm_tools()``.
    """
    self.agent.generate_action.callbacks.append(RLMReasoningCallback())
    self._patch_llm_tools()
|
||||||
|
|
||||||
|
def _patch_llm_tools(self):
    """Monkey-patch the RLM's _make_llm_tools to add structured verbose logging.

    Replaces ``self.agent._make_llm_tools`` with a factory that calls the
    original one and then swaps the ``"llm_query"`` and
    ``"llm_query_batched"`` entries of the returned mapping for wrappers
    that print what goes in and what comes out. Return values of the
    wrapped tools are passed through untouched.
    """

    def clip(text, limit):
        """Return ``text`` cut to ``limit`` characters plus '...' if longer."""
        return f"{text[:limit]}..." if len(text) > limit else text

    # Capture the original bound method before it is overwritten below.
    orig_factory = self.agent._make_llm_tools

    def verbose_factory(max_workers=8):
        tools = orig_factory(max_workers=max_workers)

        orig_q = tools["llm_query"]
        orig_b = tools["llm_query_batched"]

        def wrapped_q(prompt):
            # Log at most 100 characters of the prompt.
            print(f"{DIM}⏺ [LLM QUERY]: {clip(prompt, 100)}{RESET}\n")
            res = orig_q(prompt)
            # Log at most 200 characters of the stringified result.
            print(f"{DIM}⏺ [LLM QUERY RESULT]: {clip(str(res), 200)}{RESET}\n")
            return res

        def wrapped_b(prompts):
            # Batched calls only log counts, not contents.
            print(f"{DIM}⏺ [LLM QUERY BATCHED]: {len(prompts)} prompts{RESET}\n")
            res = orig_b(prompts)
            print(f"{DIM}⏺ [LLM QUERY BATCHED]: {len(res)} results{RESET}\n")
            return res

        tools["llm_query"] = wrapped_q
        tools["llm_query_batched"] = wrapped_b
        return tools

    self.agent._make_llm_tools = verbose_factory
|
||||||
|
|
||||||
def forward(self, task: str) -> str:
    """Forward pass for the agent.

    Args:
        task: Task description handed to the agent.

    Returns:
        ``dspy.Prediction(answer="No Task Given.")`` when ``task`` is falsy;
        otherwise whatever ``self.agent(task=task)`` returns.

    NOTE(review): the signature is annotated ``-> str``, but the empty-task
    branch returns a ``dspy.Prediction`` -- confirm the intended return type.
    """
    if not task:
        return dspy.Prediction(answer="No Task Given.")

    return self.agent(task=task)
|
||||||
|
|
||||||
def get_tools(self):
    """Get the tools for the agent.

    Returns:
        The ``self.tools`` object itself (not a copy).
    """
    return self.tools
|
||||||
|
|
||||||
def set_tool(self, name: str, tool: callable):
    """Set a tool for the agent.

    Stores ``tool`` under ``name`` in ``self.tools`` (replacing any existing
    entry with that name) and then calls ``self.reload_repl()``.

    Args:
        name: Key under which the tool is stored.
        tool: The tool object to register.
    """
    self.tools[name] = tool
    self.reload_repl()
|
||||||
|
|
||||||
def remove_tool(self, name: str):
    """Remove a tool from the agent.

    Deletes ``name`` from ``self.tools`` when present and reloads the REPL.
    An unknown ``name`` is ignored.

    Args:
        name: Key of the tool to remove.
    """
    if name in self.tools:
        del self.tools[name]
        # NOTE(review): indentation was lost in the source this was rebuilt
        # from; the reload is assumed to belong inside the guard -- confirm.
        self.reload_repl()
|
||||||
@@ -256,6 +309,8 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
def reload_repl(
|
def reload_repl(
|
||||||
self,
|
self,
|
||||||
): # we need to create a new instance for tool mutations to be passed back into the REPL
|
): # we need to create a new instance for tool mutations to be passed back into the REPL
|
||||||
|
"""Reload the REPL with the current tools."""
|
||||||
|
|
||||||
new_instance = dspy.RLM(
|
new_instance = dspy.RLM(
|
||||||
CodingAssistant,
|
CodingAssistant,
|
||||||
sub_lm=self.sub_lm,
|
sub_lm=self.sub_lm,
|
||||||
@@ -266,9 +321,12 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
)
|
)
|
||||||
new_instance.set_lm(self.lm)
|
new_instance.set_lm(self.lm)
|
||||||
self.agent = new_instance
|
self.agent = new_instance
|
||||||
|
if self.config.verbose:
|
||||||
|
self.add_logging_callbacks()
|
||||||
|
|
||||||
def reload_lms(self):
|
def reload_lms(self):
|
||||||
"""Recreate LM objects from current config. Call this after changing config.lm or config.sub_lm."""
|
"""Recreate LM objects from current config. Call this after changing config.lm or config.sub_lm."""
|
||||||
|
|
||||||
self.lm = dspy.LM(
|
self.lm = dspy.LM(
|
||||||
model=self.config.lm,
|
model=self.config.lm,
|
||||||
api_base=self.config.api_base,
|
api_base=self.config.api_base,
|
||||||
@@ -294,14 +352,19 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
fix this in a later patch for future devs.
|
fix this in a later patch for future devs.
|
||||||
"""
|
"""
|
||||||
super().load_state(state)
|
super().load_state(state)
|
||||||
# recreate LMs from config (not from saved state)
|
self.reload_lms() # recreate LMs from config (not from saved state)
|
||||||
self.reload_lms()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Build the program from a default config and publish it to every branch.
    agent = RLMCodingProgram(RLMCodingConfig())

    # Disabled manual smoke test, kept for reference:
    # displayed = agent(
    #     task="Use only the command `print(llm_query('Who is the CEO of Apple?'))` to find the answer."
    # )

    branches = ["main", "dev", "prod"]
    for branch in branches:
        agent.push_to_hub(
            MODAIC_REPO_PATH,
            commit_message="Fix config override bug by recreating LMs after load_state",
            branch=branch,
        )
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user