146 lines
5.4 KiB
Python
146 lines
5.4 KiB
Python
from __future__ import annotations
|
|
|
|
import builtins
|
|
import io
|
|
import sys
|
|
import traceback
|
|
from dataclasses import dataclass, field
|
|
from types import MappingProxyType
|
|
from typing import Any, Callable
|
|
|
|
from dspy.primitives.code_interpreter import CodeInterpreterError, FinalOutput
|
|
|
|
|
|
@dataclass
class UnsafeHostInterpreter:
    """
    A minimal CodeInterpreter implementation that executes code in the host Python process.

    Why this exists:
    - DSPy's default RLM interpreter (Deno/Pyodide) currently relies on pyodide.ffi.run_sync
      to bridge async tool calls, which fails on runtimes without WASM stack switching support.

    Tradeoff:
    - This is NOT a security sandbox. It will execute arbitrary Python code produced by the LLM.
      Use only in trusted/local environments.

    State model:
    - A single namespace dict is used as both globals and locals for every
      ``execute()`` call, so names assigned by one snippet are visible to the next
      until ``shutdown()`` clears the namespace.
    """

    # Callables injected by name into the execution namespace on every execute().
    tools: dict[str, Callable[..., str]] = field(default_factory=dict)
    # If RLM injects this attribute, we can map SUBMIT() to output fields.
    # Each entry is read via its "name" key.
    output_fields: list[dict] | None = None
    # True once start() has built the namespace; reset by shutdown().
    _started: bool = False
    # The namespace shared by all execute() calls.
    _globals: dict[str, Any] = field(default_factory=dict)

    class _SubmitSignal(BaseException):
        """Private control-flow signal raised by SUBMIT() to stop the running snippet.

        It derives from BaseException, so ``except Exception`` handlers do not
        intercept it on the way back up to ``execute()``.
        """

        def __init__(self, payload: dict[str, Any]):
            super().__init__()
            self.payload = payload

    def start(self) -> None:
        """Build the execution namespace. Does nothing if already started."""
        if self._started:
            return

        # Provide a few commonly-used stdlib modules without enabling arbitrary imports.
        # (The host interpreter is already unsafe, but keeping imports closed reduces footguns.)
        import json as _json
        import math as _math
        import re as _re

        # Start with a constrained global namespace. This is not a real sandbox.
        self._globals = {
            "__name__": "__rlm_host__",
            "__builtins__": MappingProxyType(
                {
                    # Allow common harmless builtins needed for analysis.
                    "print": builtins.print,
                    "len": builtins.len,
                    "type": builtins.type,
                    "range": builtins.range,
                    "reversed": builtins.reversed,
                    "min": builtins.min,
                    "max": builtins.max,
                    "sum": builtins.sum,
                    "sorted": builtins.sorted,
                    "enumerate": builtins.enumerate,
                    "str": builtins.str,
                    "int": builtins.int,
                    "float": builtins.float,
                    "bool": builtins.bool,
                    "dict": builtins.dict,
                    "list": builtins.list,
                    "set": builtins.set,
                    "tuple": builtins.tuple,
                    "abs": builtins.abs,
                    "all": builtins.all,
                    "any": builtins.any,
                    "zip": builtins.zip,
                }
            ),
        }
        self._globals.update({"re": _re, "json": _json, "math": _math})
        self._started = True

    def _resolve_submit_payload(
        self, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """Turn the arguments of a SUBMIT() call into the payload dict.

        Resolution order:
        1. Keyword arguments, when present, are the payload as-is (positional
           arguments are ignored in that case).
        2. Positional arguments are mapped onto ``output_fields`` in order when
           their count matches the number of output fields.
        3. A bare ``SUBMIT()`` collects every output field whose name is currently
           bound in the execution namespace.
        4. Anything else yields a payload with a single ``"error"`` entry.
        """
        if kwargs:
            return kwargs

        field_names = [f["name"] for f in (self.output_fields or [])]

        if args and len(args) == len(field_names):
            return dict(zip(field_names, args))

        if not args and field_names:
            found = {
                name: self._globals[name]
                for name in field_names
                if name in self._globals
            }
            # An empty result means the snippet never assigned any output
            # variable; fall through to the error payload instead of
            # submitting an empty dict.
            if found:
                return found

        return {
            "error": "SUBMIT called without outputs; provide kwargs or set output variables."
        }

    def execute(self, code: str, variables: dict[str, Any] | None = None) -> Any:
        """Run ``code`` in the shared namespace and return what it produced.

        Args:
            code: Python source to execute.
            variables: Optional names to bind in the namespace before running.
                They are merged into the persistent namespace, and ``tools`` are
                merged after them.

        Returns:
            ``FinalOutput(payload)`` if the code called ``SUBMIT``; otherwise the
            stripped text written to stdout/stderr, or ``None`` if that text is empty.

        Raises:
            SyntaxError: Re-raised unchanged when ``code`` does not compile.
            CodeInterpreterError: For any other ``Exception`` raised by the code;
                the message carries the exception text and the formatted traceback.
        """
        if not self._started:
            self.start()

        # Inject variables and tools into the exec namespace.
        if variables:
            self._globals.update(variables)
        self._globals.update(self.tools)

        def SUBMIT(*args: Any, **kwargs: Any) -> None:  # noqa: N802 - matches DSPy contract
            # RLM expects interpreter.execute() to RETURN a FinalOutput instance,
            # not raise it as an exception. We raise a private control-flow signal
            # and convert it into FinalOutput below.
            raise self._SubmitSignal(self._resolve_submit_payload(args, kwargs))

        # Provide SUBMIT for early termination.
        self._globals["SUBMIT"] = SUBMIT

        buf = io.StringIO()
        old_stdout, old_stderr = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = buf, buf
        try:
            # SECURITY: exec() of LLM-produced code in the host process. The
            # reduced __builtins__ mapping is a convenience, not a sandbox.
            exec(code, self._globals, self._globals)
        except self._SubmitSignal as sig:
            return FinalOutput(sig.payload)
        except SyntaxError:
            raise
        except Exception as e:
            tb = traceback.format_exc()
            raise CodeInterpreterError(f"{e}\n\n{tb}") from e
        finally:
            # Always restore the real streams, including on SUBMIT and on errors.
            sys.stdout, sys.stderr = old_stdout, old_stderr

        out = buf.getvalue().strip()
        return out or None

    def shutdown(self) -> None:
        """Drop all namespace state; the next execute() rebuilds it via start()."""
        self._globals.clear()
        self._started = False
|