Fix config override bug by recreating LMs after load_state
This commit is contained in:
22
nanocode.py
22
nanocode.py
@@ -222,19 +222,10 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
)
|
)
|
||||||
self.agent.set_lm(self.lm)
|
self.agent.set_lm(self.lm)
|
||||||
self.set_lm(self.lm)
|
self.set_lm(self.lm)
|
||||||
print(f"{BLUE}CONFIG WAS SET: {self.config}{RESET}")
|
|
||||||
print(f"{BLUE}LM WAS SET: {self.agent.get_lm()}{RESET}")
|
|
||||||
print(f"{BLUE}SUB LM WAS SET: {self.agent.sub_lm.model}{RESET}")
|
|
||||||
|
|
||||||
def forward(self, task: str) -> str:
|
def forward(self, task: str) -> str:
|
||||||
if not task:
|
if not task:
|
||||||
return dspy.Prediction(answer="No Task Given.")
|
return dspy.Prediction(answer="No Task Given.")
|
||||||
|
|
||||||
print(f"{YELLOW}DEBUG forward() - agent.get_lm(): {self.agent.get_lm()}{RESET}")
|
|
||||||
print(f"{YELLOW}DEBUG forward() - self.lm.model: {self.lm.model}{RESET}")
|
|
||||||
print(f"{YELLOW}DEBUG forward() - agent.sub_lm.model: {self.agent.sub_lm.model}{RESET}")
|
|
||||||
print(f"{YELLOW}DEBUG forward() - self.sub_lm.model: {self.sub_lm.model}{RESET}")
|
|
||||||
print(f"{YELLOW}DEBUG forward() - id(agent.get_lm()): {id(self.agent.get_lm())}, id(self.lm): {id(self.lm)}{RESET}")
|
|
||||||
|
|
||||||
return self.agent(task=task)
|
return self.agent(task=task)
|
||||||
|
|
||||||
@@ -281,6 +272,17 @@ class RLMCodingProgram(PrecompiledProgram):
|
|||||||
self.reload_repl()
|
self.reload_repl()
|
||||||
print(f"{BLUE}LMs RELOADED: {self.lm.model}, {self.sub_lm.model}{RESET}")
|
print(f"{BLUE}LMs RELOADED: {self.lm.model}, {self.sub_lm.model}{RESET}")
|
||||||
|
|
||||||
|
def load_state(self, state):
|
||||||
|
"""Override to recreate LMs from config after loading state.
|
||||||
|
|
||||||
|
PrecompiledProgram.from_precompiled() calls load_state() AFTER __init__,
|
||||||
|
which overwrites our LMs with saved state. We fix this by recreating
|
||||||
|
the LMs from self.config after the parent load_state runs.
|
||||||
|
"""
|
||||||
|
super().load_state(state)
|
||||||
|
# Recreate LMs from config (not from saved state)
|
||||||
|
self.reload_lms()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
agent = RLMCodingProgram(RLMCodingConfig())
|
agent = RLMCodingProgram(RLMCodingConfig())
|
||||||
agent.push_to_hub(MODAIC_REPO_PATH, commit_message="Add reload_lms method and debug forward()", branch="dev")
|
agent.push_to_hub(MODAIC_REPO_PATH, commit_message="Fix config override bug by recreating LMs after load_state", branch="dev")
|
||||||
|
|||||||
Reference in New Issue
Block a user