(no commit message)
This commit is contained in:
481
agent.py
Normal file
481
agent.py
Normal file
@@ -0,0 +1,481 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import dspy
|
||||
from litellm.exceptions import RateLimitError
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from modaic import PrecompiledConfig, PrecompiledProgram
|
||||
from config import SETTINGS
|
||||
from host_interpreter import UnsafeHostInterpreter
|
||||
from memory_fs import (
|
||||
mem_append_file,
|
||||
mem_create_directory,
|
||||
mem_directory_tree,
|
||||
mem_get_file_info,
|
||||
mem_list_directory,
|
||||
mem_move_file,
|
||||
mem_read_text_file,
|
||||
mem_search_files,
|
||||
mem_write_file,
|
||||
)
|
||||
|
||||
# Shared rich console used for all status panels printed by this module.
console = Console()
|
||||
|
||||
|
||||
# Configuration for MinecraftFriendProgram; each field is forwarded to the
# dspy.RLM / dspy.LM constructors in MinecraftFriendProgram.__init__.
class MinecraftFriendConfig(PrecompiledConfig):
    # Passed to dspy.RLM(max_iterations=...).
    max_iterations: int = 12
    # Passed to dspy.RLM(max_llm_calls=...).
    max_llm_calls: int = 18
    # Name -> callable mapping passed to dspy.RLM(tools=...).
    # NOTE(review): mutable `{}` default; only safe if PrecompiledConfig copies
    # field defaults per instance (as pydantic models do) - confirm.
    tools: dict[str, Callable[..., Any]] = {}
    # Model identifier for the main LM (wrapped in dspy.LM and set via set_lm).
    lm: str = SETTINGS.main_model
    # Model identifier for the sub LM (wrapped in dspy.LM and passed as sub_lm).
    sub_lm: str = SETTINGS.sub_model
    # Passed to dspy.RLM(verbose=...).
    verbose: bool = True
|
||||
|
||||
|
||||
class MinecraftFriendProgram(PrecompiledProgram):
    """Precompiled program that answers through a single ``dspy.RLM`` module."""

    config: MinecraftFriendConfig

    def __init__(self, config: MinecraftFriendConfig, **kwargs):
        """Build the RLM from ``config`` and attach the main language model.

        Args:
            config: Iteration limits, tool mapping and model names.
            **kwargs: Forwarded unchanged to ``PrecompiledProgram.__init__``.
        """
        super().__init__(config, **kwargs)
        # Settings are read back from self.config (not the raw argument) after
        # the base class has been initialised.
        settings = self.config
        rlm_options = dict(
            signature=MinecraftFriendRLM,
            max_iterations=settings.max_iterations,
            max_llm_calls=settings.max_llm_calls,
            tools=settings.tools,
            sub_lm=dspy.LM(settings.sub_lm),
            verbose=settings.verbose,
            interpreter=UnsafeHostInterpreter(),
        )
        self.rlm = dspy.RLM(**rlm_options)
        main_lm = dspy.LM(settings.lm)
        self.rlm.set_lm(main_lm)

    def forward(self, chat, memory):
        """Run the RLM on the given chat transcript and memory text."""
        inputs = {"chat": chat, "memory": memory}
        return self.rlm(**inputs)
|
||||
|
||||
|
||||
@dataclass
class AgentState:
    """Mutable bookkeeping shared across iterations of the polling loop in main_async."""

    # Fingerprint (see fingerprint()) of the chat lines seen on a previous poll;
    # compared against the current one to detect new chat.
    last_chat_fingerprint: str = ""
    # time.time() of the most recent sendChat (greeting or sendChat tool call).
    last_spoke_at: float = 0.0
    # time.time() at which the most recent RLM decision was started.
    last_decide_at: float = 0.0
|
||||
|
||||
|
||||
def extract_chat_lines(summary: str) -> list[str]:
    """Return the non-blank lines that follow the first separator line.

    The separator is a line that, after trailing whitespace is removed, is
    exactly eighteen ``=`` characters. Returned lines keep their leading
    whitespace but have trailing whitespace removed. If no separator line is
    present the result is an empty list.
    """
    separator = "=================="
    stripped = [raw.rstrip() for raw in summary.splitlines()]
    try:
        start = stripped.index(separator) + 1
    except ValueError:
        return []
    # After rstrip() a whitespace-only line is the empty string.
    return [entry for entry in stripped[start:] if entry]
|
||||
|
||||
|
||||
def drop_own_messages(lines: list[str], bot_username: str) -> list[str]:
    """Return ``lines`` without any line containing ``<bot_username>``.

    The server echoes the bot's own speech twice, as "[System] <Bot> ..." and
    as "<Bot>: ...", and both forms contain the same angle-bracket tag.
    """
    own_tag = f"<{bot_username}>"
    kept: list[str] = []
    for entry in lines:
        if own_tag in entry:
            continue
        kept.append(entry)
    return kept
|
||||
|
||||
|
||||
def fingerprint(lines: list[str]) -> str:
    """Join the last 30 lines with newlines to form a change-detection key."""
    recent = lines[-30:]
    return "\n".join(recent)
|
||||
|
||||
|
||||
def _extract_retry_after_seconds(err: Exception) -> float | None:
|
||||
# Groq/LiteLLM error strings often include: "Please try again in 16.0575s."
|
||||
s = str(err)
|
||||
marker = "try again in "
|
||||
if marker not in s:
|
||||
return None
|
||||
tail = s.split(marker, 1)[1]
|
||||
num = ""
|
||||
for ch in tail:
|
||||
if ch.isdigit() or ch == ".":
|
||||
num += ch
|
||||
continue
|
||||
break
|
||||
try:
|
||||
return float(num) if num else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _calltool_text(call_tool_result) -> str:
|
||||
# Compatible with MCP SDK TextContent blocks.
|
||||
out: list[str] = []
|
||||
for block in getattr(call_tool_result, "content", []) or []:
|
||||
if getattr(block, "type", None) == "text":
|
||||
out.append(getattr(block, "text", ""))
|
||||
return "\n".join([t for t in out if t]).strip()
|
||||
|
||||
|
||||
# DSPy signature passed to dspy.RLM(signature=...) in MinecraftFriendProgram.__init__.
# NOTE(review): the class docstring and the field `desc` strings are presumably
# used by DSPy as prompt text for the model - treat them as runtime strings and
# confirm before rewording.
class MinecraftFriendRLM(dspy.Signature):
    """
    You are a friendly AI companion playing Minecraft with Paul.

    Your ONLY way to talk is by calling MCP tools (especially `sendChat`).
    Use tools like `readChat`, `mineResource`, `lookAround`, etc. when useful.

    The `response` output is only a short internal note about what you did.
    """

    # Inputs: supplied per call by MinecraftFriendProgram.forward(chat=..., memory=...).
    chat = dspy.InputField(desc="Recent Minecraft chat lines (most recent last).")
    memory = dspy.InputField(desc="Short memory about Paul and the current goal.")
    # Output: printed to the console by main_async; the loop never sends it to chat.
    response = dspy.OutputField(desc="Short internal note (not sent to chat).")
|
||||
|
||||
|
||||
def _tool_default_from_schema(schema: dict[str, Any]) -> Any:
|
||||
# JSON schema defaults are best-effort; they may be missing.
|
||||
return schema.get("default", inspect._empty)
|
||||
|
||||
|
||||
def _make_sync_mcp_tool(
    *,
    tool: dspy.Tool,
    loop: asyncio.AbstractEventLoop,
    on_call: Callable[[], None] | None = None,
) -> Callable[..., Any]:
    """
    Wrap an async MCP-backed `dspy.Tool` into a sync callable that can safely be used
    inside RLM code execution, even while the main asyncio loop is running.

    The wrapper schedules ``tool.acall(**kwargs)`` on ``loop`` with
    ``asyncio.run_coroutine_threadsafe`` and blocks until it finishes, so it
    must be called from a thread other than the one running ``loop``.

    Args:
        tool: The tool to wrap. Its ``name``, ``desc`` and ``args`` (argument
            name -> JSON schema) are used for the wrapper's metadata.
        loop: The running event loop that owns the MCP session.
        on_call: Optional hook invoked after each successful tool call.

    Returns:
        A synchronous callable named after the tool.
    """
    schemas = tool.args or {}
    arg_order = list(schemas)

    async def _acall(**kwargs: Any) -> Any:
        return await tool.acall(**kwargs)

    def _sync(*args: Any, **kwargs: Any) -> Any:
        # Support common calling styles:
        # - tool(message="hi")
        # - tool("hi", delay=0) -> maps positional args in schema order
        # - tool({"message": "hi"}) -> dict-only positional
        if args:
            if len(args) == 1 and isinstance(args[0], dict) and not kwargs:
                kwargs = dict(args[0])
            else:
                if len(args) > len(arg_order):
                    raise TypeError(f"{tool.name} got too many positional arguments")
                for arg_name, value in zip(arg_order, args):
                    # An explicit keyword wins over a positional for the same name.
                    kwargs.setdefault(arg_name, value)

        # Blocking on the future from the loop's own thread would hang forever,
        # because the loop could never run the coroutine. Fail loudly instead.
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None
        if current_loop is loop:
            raise RuntimeError(
                f"{tool.name} must be called from a worker thread, not from the event loop"
            )

        fut = asyncio.run_coroutine_threadsafe(_acall(**kwargs), loop)
        result = fut.result()
        if on_call is not None:
            on_call()
        return result

    _sync.__name__ = tool.name
    _sync.__doc__ = tool.desc or ""

    # Give the LLM a nice signature in the RLM instructions.
    # inspect.Signature rejects a positional parameter without a default that
    # follows one with a default, so from the first such parameter onwards the
    # remaining ones are declared keyword-only (where any order is valid).
    params: list[inspect.Parameter] = []
    seen_default = False
    keyword_only = False
    try:
        for arg_name, schema in schemas.items():
            default = _tool_default_from_schema(schema)
            if default is inspect.Parameter.empty:
                if seen_default:
                    keyword_only = True
            else:
                seen_default = True
            params.append(
                inspect.Parameter(
                    arg_name,
                    kind=(
                        inspect.Parameter.KEYWORD_ONLY
                        if keyword_only
                        else inspect.Parameter.POSITIONAL_OR_KEYWORD
                    ),
                    default=default,
                )
            )
        _sync.__signature__ = inspect.Signature(parameters=params)  # type: ignore[attr-defined]
    except ValueError:
        # An argument name is not a valid Python identifier. The signature is
        # cosmetic, so keep the generic (*args, **kwargs) one rather than fail.
        pass

    return _sync
|
||||
|
||||
|
||||
def _parse_open_inventory(text: str) -> dict[str, int]:
|
||||
"""
|
||||
Parse openInventory() observation into a {item_name: count} dict.
|
||||
|
||||
Example input:
|
||||
"You just finished examining your inventory and it contains: 2 oak log, 2 birch log, 1 oak sapling."
|
||||
"""
|
||||
if "contains:" not in text:
|
||||
return {}
|
||||
tail = text.split("contains:", 1)[1].strip().rstrip(".")
|
||||
if not tail:
|
||||
return {}
|
||||
items: dict[str, int] = {}
|
||||
parts = [p.strip() for p in tail.split(",") if p.strip()]
|
||||
for p in parts:
|
||||
# "2 oak log" -> (2, "oak log")
|
||||
tokens = p.split()
|
||||
if not tokens:
|
||||
continue
|
||||
try:
|
||||
n = int(tokens[0])
|
||||
except ValueError:
|
||||
continue
|
||||
name = " ".join(tokens[1:]).strip().lower()
|
||||
if not name:
|
||||
continue
|
||||
items[name.replace(" ", "_")] = n
|
||||
return items
|
||||
|
||||
|
||||
async def main_async() -> None:
    """Run the Minecraft companion agent until the process is interrupted.

    Steps:
      1. Parse CLI flags (defaults come from ``SETTINGS``) and require a Groq
         API key.
      2. Start the ``@fundamentallabs/minecraft-mcp`` server over stdio and
         convert its tools, plus the local memory tools, to ``dspy.Tool``.
      3. With ``--validate-tools``: print the tool names and return.
      4. Otherwise build synchronous tool wrappers and helper tools, join the
         game, greet once, and poll chat forever, invoking the RLM when there
         is new chat or the bot has been silent too long.

    Raises:
        RuntimeError: If ``SETTINGS.groq_api_key`` is not set.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--host", default=SETTINGS.mcp_minecraft_host)
    p.add_argument("--mc-port", type=int, default=SETTINGS.mcp_minecraft_port)
    p.add_argument("--bot-username", default=SETTINGS.bot_username)
    p.add_argument(
        "--validate-tools",
        action="store_true",
        help="Connect to the MCP server, list tools, convert them to dspy.Tool, then exit.",
    )
    args = p.parse_args()

    if not SETTINGS.groq_api_key:
        raise RuntimeError(
            "GROQ_API_KEY is not set. Copy .env.example to .env and fill it in."
        )
    os.environ.setdefault("GROQ_API_KEY", SETTINGS.groq_api_key)

    # DSPy MCP tutorial requires dspy[mcp] and converts MCP tools via dspy.Tool.from_mcp_tool.
    # Important: DSPy's default RLM sandbox (Deno/Pyodide) cannot currently call tools in some
    # runtimes due to missing WASM stack switching. We use a host interpreter + sync tool wrappers.
    console.print(Panel(SETTINGS.main_model, title="DSPy model", border_style="cyan"))

    server_params = StdioServerParameters(
        command="npx",
        args=[
            "-y",
            "--",
            "@fundamentallabs/minecraft-mcp",
            "-h",
            args.host,
            "-p",
            str(args.mc_port),
        ],
        env=None,
    )

    state = AgentState()

    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()

            # Gather MCP tools and convert to DSPy tools (official DSPy MCP tutorial pattern).
            tools = await session.list_tools()
            dspy_tools = [dspy.Tool.from_mcp_tool(session, t) for t in tools.tools]
            # Add local "memory filesystem" tools (DSPy Tool wrappers).
            #
            # This follows DSPy's tool guidance: wrap functions in dspy.Tool and pass them via tools=...
            # https://dspy.ai/learn/programming/tools/
            memory_tools = [
                dspy.Tool(mem_list_directory),
                dspy.Tool(mem_read_text_file),
                dspy.Tool(mem_write_file),
                dspy.Tool(mem_append_file),
                dspy.Tool(mem_create_directory),
                dspy.Tool(mem_move_file),
                dspy.Tool(mem_search_files),
                dspy.Tool(mem_get_file_info),
                dspy.Tool(mem_directory_tree),
            ]

            all_tools = [*dspy_tools, *memory_tools]

            if args.validate_tools:
                console.print(
                    Panel(
                        "\n".join([t.name for t in all_tools]),
                        title=f"OK: ready {len(dspy_tools)} MCP tools + {len(memory_tools)} memory tools",
                        border_style="green",
                    )
                )
                return

            def _mark_spoke() -> None:
                # Record when the bot last talked; drives the idle-chitchat timer.
                state.last_spoke_at = time.time()

            # Build sync wrappers for MCP tools so the agent can call them inside RLM execution.
            loop = asyncio.get_running_loop()
            sync_mcp_tools: dict[str, Callable[..., Any]] = {}
            for t in dspy_tools:
                # Tools are exposed to generated code by name, so the name must
                # be a valid Python identifier.
                if not t.name.isidentifier():
                    continue
                sync_mcp_tools[t.name] = _make_sync_mcp_tool(
                    tool=t,
                    loop=loop,
                    on_call=_mark_spoke if t.name == "sendChat" else None,
                )

            # Memory tools are already sync python callables.
            sync_memory_tools: dict[str, Callable[..., Any]] = {}
            for t in memory_tools:
                if not t.name.isidentifier():
                    continue
                sync_memory_tools[t.name] = t

            # High-level "agent guardrails" tools to reduce LLM confusion / regressions.
            def inv_counts() -> dict[str, int]:
                """Return parsed inventory counts as a JSON-like dict."""
                text = sync_mcp_tools["openInventory"]()
                return _parse_open_inventory(str(text))

            def have(item_name: str) -> int:
                """Return how many of an item the bot currently has (best-effort)."""
                counts = inv_counts()
                # Inventory keys are lower-case with underscores ("oak_log"), so
                # normalise the query the same way; "oak log" would otherwise
                # always report 0.
                key = item_name.strip().lower().replace(" ", "_")
                return int(counts.get(key, 0))

            def deliver_drop(user_name: str, item_name: str, count: int) -> str:
                """Drop items near a player so they can pick them up (preferred transfer)."""
                # Read the inventory once and reuse the figure in the message.
                owned = have(item_name)
                if owned < count:
                    return f"[ERROR] Not enough {item_name}. Have {owned}."
                return str(sync_mcp_tools["dropItem"](item_name, count, user_name))

            def gather_to(
                item_name: str, target_count: int, batch: int = 8, max_rounds: int = 12
            ) -> str:
                """Iteratively mine until we have at least target_count of item_name (timeboxed)."""
                # Normalize common user phrasing.
                norm = item_name.strip().lower().replace(" ", "_")
                if norm in {"wood", "logs", "log"}:
                    norm = "oak_log"
                for _ in range(max_rounds):
                    cur = have(norm)
                    if cur >= target_count:
                        return f"OK: have {cur} {norm} (>= {target_count})"
                    try:
                        # Mine in small batches to reduce timeouts.
                        sync_mcp_tools["mineResource"](
                            norm, min(batch, max(target_count - cur, 1))
                        )
                    except Exception as e:
                        # Tool boundary: report the failure to the model as text.
                        return f"[ERROR] mineResource failed: {e}"
                return f"[WARN] Could not reach target. Have {have(norm)} {norm}."

            helper_tools: dict[str, Callable[..., Any]] = {
                "inv_counts": inv_counts,
                "have": have,
                "gather_to": gather_to,
                "deliver_drop": deliver_drop,
            }

            # Remove misleading tools that caused regressions in the logs.
            # (We can re-add later if needed.)
            sync_mcp_tools.pop("giveItemToSomeone", None)

            rlm_tools: dict[str, Callable[..., Any]] = {
                **sync_mcp_tools,
                **sync_memory_tools,
                **helper_tools,
            }

            rlm = MinecraftFriendProgram(MinecraftFriendConfig(tools=rlm_tools))

            # Join once up-front so the bot is in-world.
            join_res = await session.call_tool(
                "joinGame",
                arguments={
                    "username": args.bot_username,
                    "host": args.host,
                    "port": args.mc_port,
                },
            )
            console.print(
                Panel(_calltool_text(join_res), title="joinGame", border_style="green")
            )

            # Greet once.
            await session.call_tool(
                "sendChat",
                arguments={
                    "message": f"Hey! I’m {SETTINGS.persona_name}. I’m here—want to explore or build something?"
                },
            )
            state.last_spoke_at = time.time()

            # The memory briefing never changes, so build it once outside the loop.
            memory = (
                "You have a persistent memory filesystem under `.memory/`.\n"
                "Use these tools to store/recall information:\n"
                "- mem_list_directory(path)\n"
                "- mem_read_text_file(path, head=None, tail=None)\n"
                "- mem_write_file(path, content)\n"
                "- mem_append_file(path, content)\n"
                "- mem_search_files(path='', pattern='*', contains=None, limit=50)\n"
                "- mem_directory_tree(path='', max_depth=6)\n"
                "- mem_get_file_info(path)\n"
                "\n"
                "Suggested files:\n"
                "- profile/paul.md (stable preferences)\n"
                "- world/status.md (current world + tasks)\n"
                "- notes/log.md (timestamped scratchpad)\n"
                "\n"
                "Gameplay facts (IMPORTANT):\n"
                "- To give items to Paul, prefer `deliver_drop(user_name, item_name, count)`.\n"
                "- `giveItemToSomeone` is unreliable here; do NOT use it.\n"
                "- To gather a stack, use `gather_to('oak_log', 64)` then `deliver_drop('pmlockett', 'oak_log', 64)`.\n"
            )

            while True:
                read_res = await session.call_tool(
                    "readChat",
                    arguments={"count": 40, "filterType": "all", "timeLimit": 120},
                )
                summary = _calltool_text(read_res)
                lines = drop_own_messages(
                    extract_chat_lines(summary), args.bot_username
                )

                fp = fingerprint(lines)
                new_chat = fp != state.last_chat_fingerprint

                now = time.time()
                should_initiate = (
                    now - state.last_spoke_at
                ) > SETTINGS.idle_chitchat_seconds
                can_decide = (now - state.last_decide_at) > max(
                    SETTINGS.poll_seconds, 4.0
                )

                if (new_chat or should_initiate) and can_decide:
                    # Commit the fingerprint only when the chat is actually
                    # handed to the model. Committing it on every poll would
                    # mark a message that arrived during the decision cooldown
                    # as already seen, and the bot would never answer it.
                    state.last_chat_fingerprint = fp
                    state.last_decide_at = now
                    chat_context = "\n".join(lines[-30:])

                    try:
                        # Run RLM in a worker thread so sync tool calls can safely
                        # schedule async MCP operations onto this running event loop.
                        result = await asyncio.to_thread(
                            rlm,
                            chat=chat_context,
                            memory=memory,
                        )
                        resp = getattr(result, "response", None)
                        if resp:
                            console.print(
                                Panel(
                                    str(resp),
                                    title="RLM response",
                                    border_style="green",
                                )
                            )
                    except RateLimitError as e:
                        # Honour the provider's suggested delay, else wait 10 s.
                        wait_s = _extract_retry_after_seconds(e) or 10.0
                        console.print(
                            Panel(
                                f"Rate limited. Sleeping {wait_s:.1f}s.\n\n{e}",
                                title="Rate limit",
                                border_style="yellow",
                            )
                        )
                        await asyncio.sleep(wait_s)

                await asyncio.sleep(SETTINGS.poll_seconds)
|
||||
|
||||
|
||||
def main() -> None:
    """Synchronous entry point: run ``main_async`` to completion via ``asyncio.run``."""
    entry = main_async()
    asyncio.run(entry)
|
||||
|
||||
|
||||
# Start the agent only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user