Add reload_lms method and debug forward()

This commit is contained in:
2026-01-24 00:29:10 -08:00
parent a4ae97ef81
commit 1a005b6584

View File

@@ -229,6 +229,14 @@ class RLMCodingProgram(PrecompiledProgram):
if not task: if not task:
return dspy.Prediction(answer="No Task Given.") return dspy.Prediction(answer="No Task Given.")
# Debug: verify LM at call time
actual_lm = self.agent.get_lm()
print(f"{YELLOW}DEBUG forward() - agent.get_lm().model: {actual_lm.model}{RESET}")
print(f"{YELLOW}DEBUG forward() - self.lm.model: {self.lm.model}{RESET}")
print(f"{YELLOW}DEBUG forward() - agent.sub_lm.model: {self.agent.sub_lm.model}{RESET}")
print(f"{YELLOW}DEBUG forward() - self.sub_lm.model: {self.sub_lm.model}{RESET}")
print(f"{YELLOW}DEBUG forward() - id(agent.get_lm()): {id(actual_lm)}, id(self.lm): {id(self.lm)}{RESET}")
return self.agent(task=task) return self.agent(task=task)
def get_tools(self): def get_tools(self):
@@ -257,6 +265,23 @@ class RLMCodingProgram(PrecompiledProgram):
new_instance.set_lm(self.lm) new_instance.set_lm(self.lm)
self.agent = new_instance self.agent = new_instance
def reload_lms(self):
"""Recreate LM objects from current config. Call this after changing config.lm or config.sub_lm."""
self.lm = dspy.LM(
model=self.config.lm,
api_base=self.config.api_base,
max_tokens=self.config.max_tokens,
track_usage=self.config.track_usage,
)
self.sub_lm = dspy.LM(
model=self.config.sub_lm,
api_base=self.config.api_base,
max_tokens=self.config.max_tokens,
track_usage=self.config.track_usage,
)
self.reload_repl()
print(f"{BLUE}LMs RELOADED: {self.lm.model}, {self.sub_lm.model}{RESET}")
if __name__ == "__main__": if __name__ == "__main__":
agent = RLMCodingProgram(RLMCodingConfig()) agent = RLMCodingProgram(RLMCodingConfig())
agent.push_to_hub(MODAIC_REPO_PATH, commit_message="change signature", branch="dev") agent.push_to_hub(MODAIC_REPO_PATH, commit_message="Add reload_lms method and debug forward()", branch="dev")