Fix config override bug by recreating LMs after load_state
This commit is contained in:
76
nanocode.py
76
nanocode.py
@@ -170,7 +170,8 @@ def run_bash(cmd: str) -> str:
|
||||
class RLMReasoningCallback(BaseCallback):
    """Callback that echoes each step's reasoning and generated code."""

    def on_module_end(self, call_id, outputs, exception):
        """Print the reasoning and code carried by ``outputs``.

        Nothing is printed when ``outputs`` is falsy or lacks either a
        ``reasoning`` or a ``code`` attribute. ``call_id`` and ``exception``
        are accepted for the callback signature but are not used here.
        """
        if not outputs:
            return
        if not (hasattr(outputs, "reasoning") and hasattr(outputs, "code")):
            return

        # The reasoning is emitted once, in the dimmed format; the older
        # plain-text print of the same value was a leftover duplicate.
        print(f"{DIM}⏺ [REASONING STEP] {outputs.reasoning}\n{RESET}")
        print(f"{DIM}⏺ [CODE] {outputs.code}\n{RESET}")
|
||||
|
||||
|
||||
# -- Program ---
|
||||
@@ -230,25 +231,77 @@ class RLMCodingProgram(PrecompiledProgram):
|
||||
tools=self.tools,
|
||||
max_output_chars=self.config.max_output_chars,
|
||||
max_iterations=self.config.max_iters,
|
||||
verbose=self.config.verbose,
|
||||
)
|
||||
self.agent.set_lm(self.lm)
|
||||
self.set_lm(self.lm)
|
||||
|
||||
if self.config.verbose:
|
||||
self.add_logging_callbacks()
|
||||
|
||||
def add_logging_callbacks(self):
    """Attach the verbose-logging hooks to the current agent.

    Appends an ``RLMReasoningCallback`` to the callbacks of
    ``self.agent.generate_action`` and then wraps the agent's LLM tool
    factory through ``_patch_llm_tools``.
    """
    reasoning_logger = RLMReasoningCallback()
    self.agent.generate_action.callbacks.append(reasoning_logger)
    self._patch_llm_tools()
|
||||
|
||||
def _patch_llm_tools(self):
    """Replace the agent's ``_make_llm_tools`` with a logging variant.

    The replacement factory builds the tools with the original factory and
    then swaps ``llm_query`` and ``llm_query_batched`` for wrappers that
    print a dimmed trace line before and after delegating to the real tool.
    """
    # Keep a handle on the bound method before it is overwritten below.
    orig_factory = self.agent._make_llm_tools

    def verbose_factory(max_workers=8):
        tools = orig_factory(max_workers=max_workers)
        orig_q, orig_b = tools["llm_query"], tools["llm_query_batched"]

        def wrapped_q(prompt):
            # Prompts over 100 characters are cut and suffixed with "...".
            if len(prompt) > 100:
                print(f"{DIM}⏺ [LLM QUERY]: {prompt[:100]}...{RESET}\n")
            else:
                print(f"{DIM}⏺ [LLM QUERY]: {prompt}{RESET}\n")

            res = orig_q(prompt)

            # Results are measured and cut on their str() form at 200 chars.
            if len(str(res)) > 200:
                print(f"{DIM}⏺ [LLM QUERY RESULT]: {str(res)[:200]}...{RESET}\n")
            else:
                print(f"{DIM}⏺ [LLM QUERY RESULT]: {res}{RESET}\n")
            return res

        def wrapped_b(prompts):
            # Only the counts are logged for batched calls.
            print(f"{DIM}⏺ [LLM QUERY BATCHED]: {len(prompts)} prompts{RESET}\n")
            res = orig_b(prompts)
            print(f"{DIM}⏺ [LLM QUERY BATCHED]: {len(res)} results{RESET}\n")
            return res

        tools.update(llm_query=wrapped_q, llm_query_batched=wrapped_b)
        return tools

    self.agent._make_llm_tools = verbose_factory
|
||||
|
||||
def forward(self, task: str) -> "dspy.Prediction":
    """Run the agent on ``task``.

    Args:
        task: The task text handed to the agent. A falsy value (empty
            string, ``None``) short-circuits without calling the agent.

    Returns:
        ``dspy.Prediction(answer="No Task Given.")`` when ``task`` is falsy;
        otherwise whatever ``self.agent(task=task)`` returns. The previous
        ``-> str`` annotation did not match either branch.
    """
    if not task:
        return dspy.Prediction(answer="No Task Given.")

    return self.agent(task=task)
|
||||
|
||||
def get_tools(self):
    """Return ``self.tools`` (the same object, not a copy)."""
    registered = self.tools
    return registered
|
||||
|
||||
def set_tool(self, name: str, tool: callable):
    """Store ``tool`` under ``name`` in ``self.tools`` and reload the REPL.

    An existing entry with the same name is overwritten.
    """
    registry = self.tools
    registry[name] = tool
    # Rebuild the agent after mutating the tool set.
    self.reload_repl()
|
||||
|
||||
def remove_tool(self, name: str):
|
||||
"""Remove the tool stored under ``name`` from ``self.tools``.

The entry is deleted only when ``name`` is present, so an unknown name
does not raise ``KeyError``. ``self.reload_repl()`` is called afterwards.
NOTE(review): indentation is lost in this view, so whether the reload is
conditional on a removal having happened cannot be confirmed here.
"""
|
||||
if name in self.tools:
|
||||
del self.tools[name]
|
||||
self.reload_repl()
|
||||
@@ -256,6 +309,8 @@ class RLMCodingProgram(PrecompiledProgram):
|
||||
def reload_repl(
|
||||
self,
|
||||
): # we need to create a new instance for tool mutations to be passed back into the REPL
|
||||
"""Reload the REPL with the current tools."""
|
||||
|
||||
new_instance = dspy.RLM(
|
||||
CodingAssistant,
|
||||
sub_lm=self.sub_lm,
|
||||
@@ -266,9 +321,12 @@ class RLMCodingProgram(PrecompiledProgram):
|
||||
)
|
||||
new_instance.set_lm(self.lm)
|
||||
self.agent = new_instance
|
||||
if self.config.verbose:
|
||||
self.add_logging_callbacks()
|
||||
|
||||
def reload_lms(self):
|
||||
"""Recreate LM objects from current config. Call this after changing config.lm or config.sub_lm."""
|
||||
|
||||
self.lm = dspy.LM(
|
||||
model=self.config.lm,
|
||||
api_base=self.config.api_base,
|
||||
@@ -294,15 +352,19 @@ class RLMCodingProgram(PrecompiledProgram):
|
||||
fix this in a later patch for future devs.
|
||||
"""
|
||||
super().load_state(state)
|
||||
# recreate LMs from config (not from saved state)
|
||||
self.reload_lms()
|
||||
self.reload_lms() # recreate LMs from config (not from saved state)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    agent = RLMCodingProgram(RLMCodingConfig())

    # Manual smoke test, kept disabled:
    # displayed = agent(
    #     task="Use only the command `print(llm_query('Who is the CEO of Apple?'))` to find the answer."
    # )

    # Push the program once per branch. The loop variable is the only
    # ``branch`` argument; a second hard-coded ``branch="dev"`` keyword in
    # the same call would be a syntax error (repeated keyword argument).
    branches = ["main", "dev", "prod"]
    for branch in branches:
        agent.push_to_hub(
            MODAIC_REPO_PATH,
            commit_message="Fix config override bug by recreating LMs after load_state",
            branch=branch,
        )
|
||||
|
||||
|
||||
Reference in New Issue
Block a user