(no commit message)

This commit is contained in:
2025-11-08 16:29:47 -05:00
parent a78efdbd0f
commit f511aac5a3
21 changed files with 1806 additions and 20 deletions

3
src/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .index import *
# Only CodexAgent is advertised as public API of this package.
__all__ = ["CodexAgent"]

83
src/codex/__init__.py Normal file
View File

@@ -0,0 +1,83 @@
from __future__ import annotations
from .client import Codex
from .config import CodexOptions, SandboxMode, ThreadOptions, TurnOptions, ApprovalMode
from .events import (
ThreadEvent,
ThreadStartedEvent,
TurnStartedEvent,
TurnCompletedEvent,
TurnFailedEvent,
ItemStartedEvent,
ItemUpdatedEvent,
ItemCompletedEvent,
ThreadErrorEvent,
Usage,
)
from .items import (
ThreadItem,
AgentMessageItem,
ReasoningItem,
CommandExecutionItem,
CommandExecutionStatus,
FileChangeItem,
PatchApplyStatus,
PatchChangeKind,
McpToolCallItem,
McpToolCallStatus,
WebSearchItem,
TodoListItem,
ErrorItem,
)
from .thread import Thread, ThreadRunResult, ThreadStream
from .exceptions import (
CodexError,
UnsupportedPlatformError,
SpawnError,
ExecExitError,
JsonParseError,
ThreadRunError,
SchemaValidationError,
)
# Names re-exported as the public API of the codex package, listed in the
# order: client and options, thread types, events, items, exceptions.
__all__ = [
    "Codex",
    "CodexOptions",
    "ThreadOptions",
    "TurnOptions",
    "SandboxMode",
    "ApprovalMode",
    "Thread",
    "ThreadRunResult",
    "ThreadStream",
    "ThreadEvent",
    "ThreadStartedEvent",
    "TurnStartedEvent",
    "TurnCompletedEvent",
    "TurnFailedEvent",
    "ItemStartedEvent",
    "ItemUpdatedEvent",
    "ItemCompletedEvent",
    "ThreadErrorEvent",
    "Usage",
    "ThreadItem",
    "AgentMessageItem",
    "ReasoningItem",
    "CommandExecutionItem",
    "CommandExecutionStatus",
    "FileChangeItem",
    "PatchApplyStatus",
    "PatchChangeKind",
    "McpToolCallItem",
    "McpToolCallStatus",
    "WebSearchItem",
    "TodoListItem",
    "ErrorItem",
    "CodexError",
    "UnsupportedPlatformError",
    "SpawnError",
    "ExecExitError",
    "JsonParseError",
    "ThreadRunError",
    "SchemaValidationError",
]

22
src/codex/client.py Normal file
View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from typing import Optional
from .config import CodexOptions, ThreadOptions
from .exec import CodexExec
from .thread import Thread
class Codex:
    """Factory for :class:`Thread` objects that share one ``CodexExec`` runner.

    Args:
        options: Client-wide settings. When omitted (or falsy) a default
            ``CodexOptions()`` is used.
    """

    def __init__(self, options: Optional[CodexOptions] = None) -> None:
        self._options = options or CodexOptions()
        self._exec = CodexExec(self._options.codex_path_override)

    def start_thread(self, options: Optional[ThreadOptions] = None) -> Thread:
        """Return a new thread that has no thread id yet."""
        return Thread(self._exec, self._options, options or ThreadOptions())

    def resume_thread(self, thread_id: str, options: Optional[ThreadOptions] = None) -> Thread:
        """Return a thread that is bound to the given existing ``thread_id``."""
        return Thread(self._exec, self._options, options or ThreadOptions(), thread_id)

44
src/codex/config.py Normal file
View File

@@ -0,0 +1,44 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
from typing import Mapping, Optional, TYPE_CHECKING
# Alias for the values accepted as an output schema. Type checkers see the
# pydantic-aware union; at runtime only the plain mapping form is bound, so
# this module never imports pydantic itself.
if TYPE_CHECKING:  # pragma: no cover - typing only
    from pydantic import BaseModel as PydanticBaseModel
    SchemaInput = Mapping[str, object] | type[PydanticBaseModel] | PydanticBaseModel
else:
    SchemaInput = Mapping[str, object]
class ApprovalMode(StrEnum):
    """Closed set of approval-mode names; each member's value is its string form."""

    NEVER = "never"
    ON_REQUEST = "on-request"
    ON_FAILURE = "on-failure"
    UNTRUSTED = "untrusted"
class SandboxMode(StrEnum):
    """Closed set of sandbox-mode names; each member's value is its string form."""

    READ_ONLY = "read-only"
    WORKSPACE_WRITE = "workspace-write"
    DANGER_FULL_ACCESS = "danger-full-access"
@dataclass(frozen=True, slots=True)
class CodexOptions:
    """Immutable client-wide settings; every field is optional."""

    # Explicit path to the codex executable, used instead of the bundled one.
    codex_path_override: Optional[str] = None
    # API base URL handed to the codex process.
    base_url: Optional[str] = None
    # API key handed to the codex process.
    api_key: Optional[str] = None
@dataclass(frozen=True, slots=True)
class ThreadOptions:
    """Immutable per-thread settings; unset fields are left to the codex default."""

    # Model name to request.
    model: Optional[str] = None
    # Sandbox level for the thread.
    sandbox_mode: Optional[SandboxMode] = None
    # Directory the codex process should work in.
    working_directory: Optional[str] = None
    # When true, the git-repository check is skipped.
    skip_git_repo_check: bool = False
@dataclass(frozen=True, slots=True)
class TurnOptions:
    """Immutable settings that apply to a single turn."""

    # Schema describing the desired output; see the SchemaInput alias for
    # the accepted forms.
    output_schema: Optional[SchemaInput] = None

42
src/codex/discovery.py Normal file
View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import platform
import sys
from pathlib import Path
from .exceptions import UnsupportedPlatformError
def _detect_target() -> str:
system = sys.platform
machine = platform.machine().lower()
if system in {"linux", "linux2"}:
if machine in {"x86_64", "amd64"}:
return "x86_64-unknown-linux-musl"
if machine in {"aarch64", "arm64"}:
return "aarch64-unknown-linux-musl"
elif system == "darwin":
if machine == "x86_64":
return "x86_64-apple-darwin"
if machine in {"arm64", "aarch64"}:
return "aarch64-apple-darwin"
elif system == "win32":
if machine in {"x86_64", "amd64"}:
return "x86_64-pc-windows-msvc"
if machine in {"arm64", "aarch64"}:
return "aarch64-pc-windows-msvc"
raise UnsupportedPlatformError(system, machine)
def find_codex_binary(override: str | None = None) -> Path:
    """Return the path of the codex executable to run.

    Args:
        override: Explicit path. When truthy it is returned as a ``Path``
            without any platform detection.

    Returns:
        ``override`` as a ``Path``, or the bundled location
        ``<this package>/vendor/<target>/codex/<binary name>``. The path is
        not checked for existence.

    Raises:
        UnsupportedPlatformError: From target detection, when no override is
            given and the platform is not recognised.
    """
    if override:
        return Path(override)

    target = _detect_target()
    executable = "codex.exe" if sys.platform == "win32" else "codex"
    package_dir = Path(__file__).resolve().parent
    return package_dir / "vendor" / target / "codex" / executable

141
src/codex/events.py Normal file
View File

@@ -0,0 +1,141 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal
from .exceptions import CodexError
from .items import ThreadItem, parse_thread_item
@dataclass(frozen=True, slots=True)
class Usage:
    """Immutable token counts carried by a ``turn.completed`` event."""

    input_tokens: int
    cached_input_tokens: int
    output_tokens: int
@dataclass(frozen=True, slots=True)
class ThreadError:
    """Immutable error detail attached to a failed turn."""

    message: str
@dataclass(frozen=True, slots=True)
class ThreadStartedEvent:
    """Event announcing the id of the thread that was started."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["thread.started"] = field(default="thread.started", init=False)
    thread_id: str
@dataclass(frozen=True, slots=True)
class TurnStartedEvent:
    """Event marking the start of a turn; it carries no payload."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["turn.started"] = field(default="turn.started", init=False)
@dataclass(frozen=True, slots=True)
class TurnCompletedEvent:
    """Event marking a completed turn, with its token usage."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["turn.completed"] = field(default="turn.completed", init=False)
    usage: Usage
@dataclass(frozen=True, slots=True)
class TurnFailedEvent:
    """Event marking a failed turn, with the error that caused it."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["turn.failed"] = field(default="turn.failed", init=False)
    error: ThreadError
@dataclass(frozen=True, slots=True)
class ItemStartedEvent:
    """Event carrying a thread item that has just started."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["item.started"] = field(default="item.started", init=False)
    item: ThreadItem
@dataclass(frozen=True, slots=True)
class ItemUpdatedEvent:
    """Event carrying an updated snapshot of a thread item."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["item.updated"] = field(default="item.updated", init=False)
    item: ThreadItem
@dataclass(frozen=True, slots=True)
class ItemCompletedEvent:
    """Event carrying a thread item that has completed."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["item.completed"] = field(default="item.completed", init=False)
    item: ThreadItem
@dataclass(frozen=True, slots=True)
class ThreadErrorEvent:
    """Top-level ``error`` event carrying a message."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["error"] = field(default="error", init=False)
    message: str
# Union of every event class defined in this module.
ThreadEvent = (
    ThreadStartedEvent
    | TurnStartedEvent
    | TurnCompletedEvent
    | TurnFailedEvent
    | ItemStartedEvent
    | ItemUpdatedEvent
    | ItemCompletedEvent
    | ThreadErrorEvent
)
def _ensure_dict(payload: object) -> dict[str, object]:
if isinstance(payload, dict):
return payload
raise CodexError("Event payload must be an object")
def _ensure_str(value: object, field: str) -> str:
if isinstance(value, str):
return value
raise CodexError(f"Expected string for {field}")
def _ensure_int(value: object, field: str) -> int:
if isinstance(value, int):
return value
raise CodexError(f"Expected integer for {field}")
def _parse_usage(payload: object) -> Usage:
    """Build a :class:`Usage` from a payload holding the three token counts.

    Fields are validated in the order ``input_tokens``,
    ``cached_input_tokens``, ``output_tokens``.

    Raises:
        CodexError: If the payload is not a dict or a count is not an integer.
    """
    data = _ensure_dict(payload)
    counter_names = ("input_tokens", "cached_input_tokens", "output_tokens")
    counts = {name: _ensure_int(data.get(name), name) for name in counter_names}
    return Usage(**counts)
def parse_thread_event(payload: object) -> ThreadEvent:
    """Convert one decoded JSON payload into its typed event object.

    The payload's ``"type"`` string selects the event class; the remaining
    keys are validated and copied onto that class.

    Raises:
        CodexError: If the payload is not a dict, a required field has the
            wrong type, or the event type is not recognised.
    """
    data = _ensure_dict(payload)
    type_name = _ensure_str(data.get("type"), "type")

    # The three item events share one payload shape and differ only in class.
    item_event_classes = {
        "item.started": ItemStartedEvent,
        "item.updated": ItemUpdatedEvent,
        "item.completed": ItemCompletedEvent,
    }

    if type_name == "thread.started":
        return ThreadStartedEvent(thread_id=_ensure_str(data.get("thread_id"), "thread_id"))
    elif type_name == "turn.started":
        return TurnStartedEvent()
    elif type_name == "turn.completed":
        return TurnCompletedEvent(usage=_parse_usage(data.get("usage")))
    elif type_name == "turn.failed":
        error_payload = _ensure_dict(data.get("error"))
        failure = ThreadError(message=_ensure_str(error_payload.get("message"), "error.message"))
        return TurnFailedEvent(error=failure)
    elif type_name in item_event_classes:
        parsed_item = parse_thread_item(data.get("item"))
        return item_event_classes[type_name](item=parsed_item)
    elif type_name == "error":
        return ThreadErrorEvent(message=_ensure_str(data.get("message"), "message"))

    raise CodexError(f"Unsupported event type: {type_name}")

63
src/codex/exceptions.py Normal file
View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
class CodexError(Exception):
    """Base exception for Codex SDK.

    The other exception classes in this module derive from it.
    """
def _format_command(command: Sequence[str] | None) -> str:
if not command:
return "<unknown>"
return " ".join(command)
class UnsupportedPlatformError(CodexError):
    """Raised when the OS/CPU combination has no known target.

    Attributes:
        platform: The platform name that was passed in.
        machine: The machine name that was passed in.
    """

    def __init__(self, platform: str, machine: str) -> None:
        self.platform = platform
        self.machine = machine
        super().__init__(f"Unsupported platform: {platform} ({machine})")
class SpawnError(CodexError):
    """Raised when the codex process could not be started.

    Attributes:
        command: The attempted command as a list, or ``None`` when no
            (or an empty) command was given.
        original_error: The ``OSError`` passed in by the caller.
    """

    def __init__(self, command: Sequence[str] | None, error: OSError) -> None:
        if command:
            self.command = list(command)
        else:
            self.command = None
        self.original_error = error
        super().__init__(f"Failed to spawn codex exec: {_format_command(self.command)}: {error}")
@dataclass(slots=True)
class ExecExitError(CodexError):
    """Error carrying a command, its exit code and its captured stderr."""

    command: tuple[str, ...]
    exit_code: int
    stderr: str

    def __str__(self) -> str:  # pragma: no cover - trivial formatting
        """Return the exit code, followed by stderr when it is non-blank."""
        stderr = self.stderr.strip()
        if stderr:
            tail = f": {stderr}"
        else:
            tail = ""
        return f"codex exec exited with code {self.exit_code}{tail}"
@dataclass(slots=True)
class JsonParseError(CodexError):
    """Error carrying a raw output line that could not be parsed, and its command."""

    raw_line: str
    command: tuple[str, ...]

    def __str__(self) -> str:  # pragma: no cover - trivial formatting
        """Return the message with the offending line cut to 200 characters."""
        too_long = len(self.raw_line) > 200
        # Keep 197 characters and append an ellipsis so the sample is 200 long.
        sample = self.raw_line[:197] + "..." if too_long else self.raw_line
        return f"Failed to parse codex event: {sample}"
class ThreadRunError(CodexError):
    """Codex error that carries only a message string."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
class SchemaValidationError(CodexError):
    """Codex error that carries only a message string."""

    def __init__(self, message: str) -> None:
        super().__init__(message)

132
src/codex/exec.py Normal file
View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import io
import os
import subprocess
from dataclasses import dataclass
from threading import Thread
from typing import Iterator, Optional
from .config import SandboxMode
from .discovery import find_codex_binary
from .exceptions import ExecExitError, SpawnError
# Name of an environment variable, and the value this SDK uses for it.
INTERNAL_ORIGINATOR_ENV = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE"
PYTHON_SDK_ORIGINATOR = "codex_sdk_py"
@dataclass(frozen=True, slots=True)
class ExecArgs:
    """Immutable bundle of everything needed for one ``codex exec`` run."""

    # Prompt text for the run.
    input: str
    # API base URL and key for the run, when set.
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    # Id of an existing thread to resume, when set.
    thread_id: Optional[str] = None
    # Optional run settings; unset values are left to the codex default.
    model: Optional[str] = None
    sandbox_mode: Optional[SandboxMode] = None
    working_directory: Optional[str] = None
    skip_git_repo_check: bool = False
    # Path of a file holding the output schema, when one is used.
    output_schema_path: Optional[str] = None
class CodexExec:
    """Runs the codex executable and streams its stdout line by line."""

    def __init__(self, executable_override: Optional[str] = None) -> None:
        """Resolve the executable path once, honouring ``executable_override``."""
        self._binary = find_codex_binary(executable_override)

    def build_command(self, args: ExecArgs) -> list[str]:
        """Return the argv list for ``args``.

        Flags are added only for truthy values. ``resume <thread_id>`` is
        appended last when a thread id is present.
        """
        command = [str(self._binary), "exec", "--experimental-json"]
        if args.model:
            command.extend(["--model", args.model])
        if args.sandbox_mode:
            command.extend(["--sandbox", args.sandbox_mode.value])
        if args.working_directory:
            command.extend(["--cd", args.working_directory])
        if args.skip_git_repo_check:
            command.append("--skip-git-repo-check")
        if args.output_schema_path:
            command.extend(["--output-schema", args.output_schema_path])
        if args.thread_id:
            command.extend(["resume", args.thread_id])
        return command

    def run_lines(self, args: ExecArgs) -> Iterator[str]:
        """Spawn the process, send the prompt on stdin and yield stdout lines.

        Each yielded line has its trailing newline characters removed.
        Because this is a generator, nothing runs until iteration starts.

        Raises:
            SpawnError: If the process cannot be started or lacks stdio pipes.
            ExecExitError: If, after stdout is exhausted, the process exit
                code is non-zero; carries the collected stderr text.
        """
        command = self.build_command(args)
        # The child inherits the current environment plus the overrides below.
        env = os.environ.copy()
        # setdefault: a value already present in the environment wins.
        env.setdefault(INTERNAL_ORIGINATOR_ENV, PYTHON_SDK_ORIGINATOR)
        if args.base_url:
            env["OPENAI_BASE_URL"] = args.base_url
        if args.api_key:
            env["CODEX_API_KEY"] = args.api_key
        stderr_buffer: list[str] = []
        try:
            process = subprocess.Popen(
                command,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding="utf-8",
                errors="strict",
                env=env,
            )
        except OSError as error:  # pragma: no cover - exercised indirectly
            raise SpawnError(command, error) from error
        if not process.stdin or not process.stdout:
            process.kill()
            raise SpawnError(command, OSError("Missing stdio pipes"))
        stderr_thread: Thread | None = None
        if process.stderr:
            def _drain_stderr(pipe: io.TextIOBase, buffer: list[str]) -> None:
                # Collect stderr lines until EOF ("") or until the pipe is
                # closed underneath us (readline then raises ValueError).
                while True:
                    try:
                        chunk = pipe.readline()
                    except ValueError:
                        break
                    if chunk == "":
                        break
                    buffer.append(chunk)
            # stderr is read on a daemon thread while stdout is consumed here;
            # presumably this keeps the child from stalling on a full stderr
            # pipe - confirm.
            stderr_thread = Thread(
                target=_drain_stderr,
                args=(process.stderr, stderr_buffer),
                daemon=True,
            )
            stderr_thread.start()
        try:
            # Send the whole prompt, then close stdin.
            process.stdin.write(args.input)
            process.stdin.close()
            # readline returns "" at EOF, which ends the iteration.
            for line in iter(process.stdout.readline, ""):
                yield line.rstrip("\n")
            return_code = process.wait()
            if stderr_thread is not None:
                stderr_thread.join()
            stderr_output = "".join(stderr_buffer)
            if return_code != 0:
                raise ExecExitError(tuple(command), return_code, stderr_output)
        finally:
            # Runs on normal completion, on errors, and when the consumer
            # stops iterating early: close pipes and make sure the child is
            # not left running.
            if process.stdout and not process.stdout.closed:
                process.stdout.close()
            if process.stderr and not process.stderr.closed:
                try:
                    process.stderr.close()
                except ValueError:
                    pass
            if stderr_thread is not None and stderr_thread.is_alive():
                # Bounded wait only; the thread is a daemon.
                stderr_thread.join(timeout=0.1)
            returncode = process.poll()
            if returncode is None:
                # Still running: kill it, then reap it.
                process.kill()
                try:
                    process.wait(timeout=0.5)
                except subprocess.TimeoutExpired:
                    process.wait()

228
src/codex/items.py Normal file
View File

@@ -0,0 +1,228 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Iterable, Literal, Sequence, cast
from .exceptions import CodexError
class CommandExecutionStatus(StrEnum):
    """Status values for a command-execution item."""

    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
class PatchChangeKind(StrEnum):
    """Kinds of change a single file-change entry can have."""

    ADD = "add"
    DELETE = "delete"
    UPDATE = "update"
class PatchApplyStatus(StrEnum):
    """Status values for a file-change item."""

    COMPLETED = "completed"
    FAILED = "failed"
class McpToolCallStatus(StrEnum):
    """Status values for an MCP tool-call item."""

    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
@dataclass(frozen=True, slots=True)
class CommandExecutionItem:
    """Thread item describing a command execution."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["command_execution"] = field(default="command_execution", init=False)
    id: str
    command: str
    aggregated_output: str
    status: CommandExecutionStatus
    # Defaults to None when no exit code is supplied.
    exit_code: int | None = None
@dataclass(frozen=True, slots=True)
class FileUpdateChange:
    """One changed path and the kind of change applied to it."""

    path: str
    kind: PatchChangeKind
@dataclass(frozen=True, slots=True)
class FileChangeItem:
    """Thread item describing a set of file changes and their status."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["file_change"] = field(default="file_change", init=False)
    id: str
    changes: Sequence[FileUpdateChange]
    status: PatchApplyStatus
@dataclass(frozen=True, slots=True)
class McpToolCallItem:
    """Thread item describing a call to a tool on an MCP server."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["mcp_tool_call"] = field(default="mcp_tool_call", init=False)
    id: str
    server: str
    tool: str
    status: McpToolCallStatus
@dataclass(frozen=True, slots=True)
class AgentMessageItem:
    """Thread item carrying the text of an agent message."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["agent_message"] = field(default="agent_message", init=False)
    id: str
    text: str
@dataclass(frozen=True, slots=True)
class ReasoningItem:
    """Thread item carrying reasoning text."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["reasoning"] = field(default="reasoning", init=False)
    id: str
    text: str
@dataclass(frozen=True, slots=True)
class WebSearchItem:
    """Thread item carrying a web-search query string."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["web_search"] = field(default="web_search", init=False)
    id: str
    query: str
@dataclass(frozen=True, slots=True)
class ErrorItem:
    """Thread item carrying an error message."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["error"] = field(default="error", init=False)
    id: str
    message: str
@dataclass(frozen=True, slots=True)
class TodoItem:
    """One entry of a todo list: its text and whether it is completed."""

    text: str
    completed: bool
@dataclass(frozen=True, slots=True)
class TodoListItem:
    """Thread item carrying a list of todo entries."""

    # Fixed discriminator; not a constructor argument (init=False).
    type: Literal["todo_list"] = field(default="todo_list", init=False)
    id: str
    items: Sequence[TodoItem]
# Union of the eight thread-item classes listed below.
ThreadItem = (
    AgentMessageItem
    | ReasoningItem
    | CommandExecutionItem
    | FileChangeItem
    | McpToolCallItem
    | WebSearchItem
    | TodoListItem
    | ErrorItem
)
def _ensure_str(value: object, field: str) -> str:
if isinstance(value, str):
return value
raise CodexError(f"Expected string for {field}")
def _ensure_sequence(value: object, field: str) -> Sequence[object]:
if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
return cast(Sequence[object], value)
raise CodexError(f"Expected sequence for {field}")
def _parse_changes(values: Iterable[object]) -> list[FileUpdateChange]:
changes: list[FileUpdateChange] = []
for value in values:
if not isinstance(value, dict):
raise CodexError("Invalid file change entry")
path = _ensure_str(value.get("path"), "path")
kind = _ensure_str(value.get("kind"), "kind")
try:
enum_kind = PatchChangeKind(kind)
except ValueError as exc:
raise CodexError(f"Unsupported file change kind: {kind}") from exc
changes.append(FileUpdateChange(path=path, kind=enum_kind))
return changes
def _parse_todos(values: Iterable[object]) -> list[TodoItem]:
todos: list[TodoItem] = []
for value in values:
if not isinstance(value, dict):
raise CodexError("Invalid todo entry")
text = _ensure_str(value.get("text"), "text")
completed = bool(value.get("completed", False))
todos.append(TodoItem(text=text, completed=completed))
return todos
def _build_agent_message(item_id: str, payload: dict) -> ThreadItem:
    """Build an AgentMessageItem from its payload."""
    text = _ensure_str(payload.get("text"), "text")
    return AgentMessageItem(id=item_id, text=text)


def _build_reasoning(item_id: str, payload: dict) -> ThreadItem:
    """Build a ReasoningItem from its payload."""
    text = _ensure_str(payload.get("text"), "text")
    return ReasoningItem(id=item_id, text=text)


def _build_command_execution(item_id: str, payload: dict) -> ThreadItem:
    """Build a CommandExecutionItem from its payload."""
    command = _ensure_str(payload.get("command"), "command")
    aggregated_output = _ensure_str(payload.get("aggregated_output"), "aggregated_output")
    status_str = _ensure_str(payload.get("status"), "status")
    try:
        status = CommandExecutionStatus(status_str)
    except ValueError as exc:
        raise CodexError(f"Unsupported command execution status: {status_str}") from exc
    exit_code = payload.get("exit_code")
    # Anything that is not an int (including a missing key) becomes None.
    exit_value = int(exit_code) if isinstance(exit_code, int) else None
    return CommandExecutionItem(
        id=item_id,
        command=command,
        aggregated_output=aggregated_output,
        status=status,
        exit_code=exit_value,
    )


def _build_file_change(item_id: str, payload: dict) -> ThreadItem:
    """Build a FileChangeItem from its payload."""
    changes_raw = _ensure_sequence(payload.get("changes"), "changes")
    status_str = _ensure_str(payload.get("status"), "status")
    try:
        change_status = PatchApplyStatus(status_str)
    except ValueError as exc:
        raise CodexError(f"Unsupported file change status: {status_str}") from exc
    # The entries are parsed only after the status has been validated.
    changes = _parse_changes(changes_raw)
    return FileChangeItem(id=item_id, changes=changes, status=change_status)


def _build_mcp_tool_call(item_id: str, payload: dict) -> ThreadItem:
    """Build a McpToolCallItem from its payload."""
    server = _ensure_str(payload.get("server"), "server")
    tool = _ensure_str(payload.get("tool"), "tool")
    status_str = _ensure_str(payload.get("status"), "status")
    try:
        call_status = McpToolCallStatus(status_str)
    except ValueError as exc:
        raise CodexError(f"Unsupported MCP tool call status: {status_str}") from exc
    return McpToolCallItem(id=item_id, server=server, tool=tool, status=call_status)


def _build_web_search(item_id: str, payload: dict) -> ThreadItem:
    """Build a WebSearchItem from its payload."""
    query = _ensure_str(payload.get("query"), "query")
    return WebSearchItem(id=item_id, query=query)


def _build_error(item_id: str, payload: dict) -> ThreadItem:
    """Build an ErrorItem from its payload."""
    message = _ensure_str(payload.get("message"), "message")
    return ErrorItem(id=item_id, message=message)


def _build_todo_list(item_id: str, payload: dict) -> ThreadItem:
    """Build a TodoListItem from its payload."""
    todos_raw = _ensure_sequence(payload.get("items"), "items")
    return TodoListItem(id=item_id, items=_parse_todos(todos_raw))


# Maps the payload's "type" string to the function that builds that item.
_ITEM_BUILDERS = {
    "agent_message": _build_agent_message,
    "reasoning": _build_reasoning,
    "command_execution": _build_command_execution,
    "file_change": _build_file_change,
    "mcp_tool_call": _build_mcp_tool_call,
    "web_search": _build_web_search,
    "error": _build_error,
    "todo_list": _build_todo_list,
}


def parse_thread_item(payload: object) -> ThreadItem:
    """Convert one decoded item payload into its typed item object.

    ``"type"`` and ``"id"`` are validated first, in that order, for every
    item; the type then selects which further fields are read.

    Raises:
        CodexError: If the payload is not a dict, a required field has the
            wrong type, a status/kind value is unknown, or the item type is
            not recognised.
    """
    if not isinstance(payload, dict):
        raise CodexError("Thread item must be an object")
    type_name = _ensure_str(payload.get("type"), "type")
    item_id = _ensure_str(payload.get("id"), "id")

    builder = _ITEM_BUILDERS.get(type_name)
    if builder is None:
        raise CodexError(f"Unsupported item type: {type_name}")
    return builder(item_id, payload)

89
src/codex/schema.py Normal file
View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import json
import tempfile
from collections.abc import Mapping
from pathlib import Path
from types import TracebackType
from typing import Any, Type, cast
from functools import lru_cache
from .exceptions import SchemaValidationError
from .config import SchemaInput
@lru_cache(maxsize=1)
def _get_pydantic_base_model() -> Type[Any] | None: # pragma: no cover - import guard
try:
from pydantic import BaseModel
except ImportError:
return None
return cast(Type[Any], BaseModel)
def _is_pydantic_model(value: object) -> bool:
    """Return True when ``value`` is a class deriving from pydantic's BaseModel.

    Always False when pydantic is not importable.
    """
    base_model = _get_pydantic_base_model()
    if not isinstance(value, type):
        return False
    if base_model is None:
        return False
    return issubclass(value, base_model)
def _is_pydantic_instance(value: object) -> bool:
    """Return True when ``value`` is an instance of pydantic's BaseModel.

    Always False when pydantic is not importable.
    """
    base_model = _get_pydantic_base_model()
    if base_model is None:
        return False
    return isinstance(value, base_model)
def _convert_schema_input(schema: SchemaInput | None) -> Mapping[str, object] | None:
if schema is None or isinstance(schema, Mapping):
return schema
if _is_pydantic_model(schema):
return cast(Mapping[str, object], schema.model_json_schema())
if _is_pydantic_instance(schema):
return cast(Mapping[str, object], schema.model_json_schema())
raise SchemaValidationError(
"output_schema must be a mapping or a Pydantic BaseModel (class or instance)",
)
class SchemaTempFile:
    """Context manager that writes an output schema to a temporary JSON file.

    On entry the schema (if any) is normalised, written to ``schema.json``
    inside a fresh temporary directory, and exposed through ``path``. On
    exit the directory is removed and ``path`` is reset to ``None``. When
    no schema is given, no file is created and ``path`` stays ``None``.

    Args:
        schema: A mapping, a pydantic model class/instance, or ``None``.
    """

    def __init__(self, schema: SchemaInput | None) -> None:
        self._raw_schema = schema
        self._temp_dir: tempfile.TemporaryDirectory[str] | None = None
        self.path: Path | None = None

    def __enter__(self) -> SchemaTempFile:
        """Materialise the schema file and return ``self``.

        Raises:
            SchemaValidationError: If the schema input is unsupported or has
                a non-string top-level key.
            TypeError: Propagated from ``json.dump`` for values that cannot
                be serialised. The temporary directory is removed first.
        """
        schema = _convert_schema_input(self._raw_schema)
        if schema is None:
            return self
        if not all(isinstance(key, str) for key in schema):
            raise SchemaValidationError("output_schema keys must be strings")
        self._temp_dir = tempfile.TemporaryDirectory(prefix="codex-output-schema-")
        try:
            schema_path = Path(self._temp_dir.name) / "schema.json"
            with schema_path.open("w", encoding="utf-8") as handle:
                # json only serialises real dicts, so copy the top level in
                # case the caller passed some other Mapping implementation.
                json.dump(dict(schema), handle, ensure_ascii=False)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so the directory
            # has to be removed here or it would linger until garbage
            # collection.
            self.cleanup()
            raise
        self.path = schema_path
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Remove the temporary directory; exceptions are not suppressed."""
        self.cleanup()

    def cleanup(self) -> None:
        """Delete the temporary directory, if any, and clear ``path``.

        Safe to call more than once.
        """
        if self._temp_dir is not None:
            self._temp_dir.cleanup()
            self._temp_dir = None
        self.path = None
def prepare_schema_file(schema: SchemaInput | None) -> SchemaTempFile:
    """Return a ``SchemaTempFile`` wrapping ``schema``.

    Nothing is written until the returned object is entered as a context
    manager.
    """
    return SchemaTempFile(schema)

113
src/codex/thread.py Normal file
View File

@@ -0,0 +1,113 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Iterator, Optional
from .config import CodexOptions, ThreadOptions, TurnOptions
from .events import (
ItemCompletedEvent,
ThreadErrorEvent,
ThreadEvent,
ThreadStartedEvent,
TurnCompletedEvent,
TurnFailedEvent,
Usage,
parse_thread_event,
)
from .exceptions import JsonParseError, ThreadRunError
from .exec import CodexExec, ExecArgs
from .items import AgentMessageItem, ThreadItem
from .schema import prepare_schema_file
@dataclass(frozen=True, slots=True)
class ThreadRunResult:
    """Aggregated result of a run: items, final response text and usage."""

    items: list[ThreadItem]
    final_response: str
    usage: Optional[Usage]
@dataclass(frozen=True, slots=True)
class ThreadStream:
    """Iterable wrapper around the event iterator of a streamed turn."""

    events: Iterator[ThreadEvent]

    def __iter__(self) -> Iterator[ThreadEvent]:
        # Returns the wrapped iterator itself, so every iteration over this
        # object consumes the same underlying stream.
        return self.events
class Thread:
    """A Codex conversation that is advanced one turn at a time.

    The thread id starts as the value given to the constructor and is
    replaced by the id carried in any ``thread.started`` event seen while
    streaming, so later turns are issued against that id.
    """

    def __init__(
        self,
        exec_client: CodexExec,
        codex_options: CodexOptions,
        thread_options: ThreadOptions,
        thread_id: Optional[str] = None,
    ) -> None:
        self._id = thread_id
        self._exec = exec_client
        self._codex_options = codex_options
        self._thread_options = thread_options

    @property
    def id(self) -> Optional[str]:
        """Current thread id, or ``None`` if none is known yet."""
        return self._id

    def run_streamed(self, prompt: str, turn_options: Optional[TurnOptions] = None) -> ThreadStream:
        """Return a lazy stream of events for one turn.

        Nothing is executed until the returned stream is iterated.
        """
        return ThreadStream(events=self._stream_events(prompt, turn_options))

    def run(self, prompt: str, turn_options: Optional[TurnOptions] = None) -> ThreadRunResult:
        """Run one turn to completion and collect its outcome.

        Returns:
            The completed items in arrival order, the text of the last
            completed agent message (``""`` if there was none) and the usage
            from the ``turn.completed`` event (``None`` if none was seen).

        Raises:
            ThreadRunError: On a top-level error event or a failed turn.
        """
        collected: list[ThreadItem] = []
        last_message = ""
        token_usage: Optional[Usage] = None
        failure_message: Optional[str] = None

        for event in self._stream_events(prompt, turn_options):
            if isinstance(event, ThreadErrorEvent):
                raise ThreadRunError(event.message)
            elif isinstance(event, TurnFailedEvent):
                # Leave the loop first, then raise below, so the event
                # iterator is released before the error propagates.
                failure_message = event.error.message
                break
            elif isinstance(event, TurnCompletedEvent):
                token_usage = event.usage
            elif isinstance(event, ItemCompletedEvent):
                collected.append(event.item)
                if isinstance(event.item, AgentMessageItem):
                    last_message = event.item.text

        if failure_message is not None:
            raise ThreadRunError(failure_message)
        return ThreadRunResult(items=collected, final_response=last_message, usage=token_usage)

    def _stream_events(
        self,
        prompt: str,
        turn_options: Optional[TurnOptions],
    ) -> Iterator[ThreadEvent]:
        """Yield parsed events for one turn.

        The optional output schema is kept in a temporary file for as long
        as the generator is running.

        Raises:
            JsonParseError: If an output line is not valid JSON.
        """
        effective_turn = turn_options or TurnOptions()
        codex_opts = self._codex_options
        thread_opts = self._thread_options

        with prepare_schema_file(effective_turn.output_schema) as schema_file:
            schema_path = schema_file.path
            exec_args = ExecArgs(
                input=prompt,
                base_url=codex_opts.base_url,
                api_key=codex_opts.api_key,
                thread_id=self._id,
                model=thread_opts.model,
                sandbox_mode=thread_opts.sandbox_mode,
                working_directory=thread_opts.working_directory,
                skip_git_repo_check=thread_opts.skip_git_repo_check,
                output_schema_path=str(schema_path) if schema_path else None,
            )
            # Computed up front so a parse error can report the command.
            command = tuple(self._exec.build_command(exec_args))
            for raw_line in self._exec.run_lines(exec_args):
                try:
                    payload = json.loads(raw_line)
                except json.JSONDecodeError as error:
                    raise JsonParseError(raw_line, command) from error
                event = parse_thread_event(payload)
                if isinstance(event, ThreadStartedEvent):
                    # Remember the id so the next turn uses it.
                    self._id = event.thread_id
                yield event

View File

@@ -0,0 +1,10 @@
"""CodexModule - DSPy module for OpenAI Codex SDK.
This package provides a signature-driven interface to the Codex agent SDK,
enabling stateful agentic workflows through DSPy signatures.
"""
from .agent import CodexModule
# Public API of this package, and the package version string.
__all__ = ["CodexModule"]
__version__ = "0.1.0"

201
src/codex_dspy/agent.py Normal file
View File

@@ -0,0 +1,201 @@
"""CodexModule - DSPy module wrapping OpenAI Codex SDK.

This module provides a signature-driven interface to the Codex agent SDK.
Each CodexModule instance maintains a stateful thread that accumulates context
across multiple forward() calls.
"""
from typing import Any, Optional, Union, get_args, get_origin
from pydantic import BaseModel
import dspy
from dspy.primitives.prediction import Prediction
from dspy.signatures.signature import Signature, ensure_signature
from ..codex import Codex, CodexOptions, SandboxMode, ThreadOptions, TurnOptions
def _is_str_type(annotation: Any) -> bool:
"""Check if annotation is str or Optional[str].
Args:
annotation: Type annotation to check
Returns:
True if annotation is str, Optional[str], or Union[str, None]
"""
if annotation == str:
return True
origin = get_origin(annotation)
if origin is Union:
args = get_args(annotation)
# Check for Optional[str] which is Union[str, None]
if len(args) == 2 and str in args and type(None) in args:
return True
return False
class CodexModule(dspy.Module):
    """DSPy module that answers a one-input, one-output signature via Codex.

    Each instance owns a single Codex thread, created in the constructor.
    Every ``forward()`` call on the same instance runs on that thread, so
    the conversation carries over between calls.

    Args:
        signature: DSPy signature (must have exactly 1 input and 1 output field)
        working_directory: Directory where Codex agent will execute commands
        model: Model to use. Defaults to Codex default.
        sandbox_mode: Execution sandbox level (READ_ONLY, WORKSPACE_WRITE, DANGER_FULL_ACCESS)
        skip_git_repo_check: Allow non-git directories as working_directory
        api_key: API key forwarded to the Codex client
        base_url: API base URL forwarded to the Codex client
        codex_path_override: Override path to codex binary (for testing)

    Raises:
        ValueError: If the signature does not have exactly one input field
            and exactly one output field.

    Example:
        >>> sig = dspy.Signature('message:str -> answer:str')
        >>> agent = CodexModule(sig, working_directory=".")
        >>> result = agent(message="What files are in this directory?")
        >>> print(result.answer)  # str response
        >>> print(result.trace)   # list of items (commands, files, etc.)
        >>> print(result.usage)   # token counts

    Example with Pydantic output:
        >>> class BugReport(BaseModel):
        ...     severity: str
        ...     description: str
        >>> sig = dspy.Signature('message:str -> report:BugReport')
        >>> agent = CodexModule(sig, working_directory=".")
        >>> result = agent(message="Analyze the bug")
        >>> print(result.report.severity)  # typed access
    """

    def __init__(
        self,
        signature: str | type[Signature],
        working_directory: str,
        model: Optional[str] = None,
        sandbox_mode: Optional[SandboxMode] = None,
        skip_git_repo_check: bool = False,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        codex_path_override: Optional[str] = None,
    ):
        super().__init__()

        # Normalise string signatures into a Signature class.
        self.signature = ensure_signature(signature)

        # The module maps one message to one answer, so enforce 1-in / 1-out.
        input_fields = list(self.signature.input_fields.keys())
        if len(input_fields) != 1:
            raise ValueError(
                f"CodexAgent requires exactly 1 input field, got {len(input_fields)}: {input_fields}\n"
                f"Example: dspy.Signature('message:str -> answer:str')"
            )

        output_fields = list(self.signature.output_fields.keys())
        if len(output_fields) != 1:
            raise ValueError(
                f"CodexAgent requires exactly 1 output field, got {len(output_fields)}: {output_fields}\n"
                f"Example: dspy.Signature('message:str -> answer:str')"
            )

        # Remember the single field on each side and the declared output type.
        self.input_field = input_fields[0]
        self.output_field = output_fields[0]
        self.output_field_info = self.signature.output_fields[self.output_field]
        self.output_type = self.output_field_info.annotation

        client_options = CodexOptions(
            api_key=api_key,
            base_url=base_url,
            codex_path_override=codex_path_override,
        )
        self.client = Codex(options=client_options)

        # One module instance owns exactly one thread.
        thread_options = ThreadOptions(
            working_directory=working_directory,
            model=model,
            sandbox_mode=sandbox_mode,
            skip_git_repo_check=skip_git_repo_check,
        )
        self.thread = self.client.start_thread(options=thread_options)

    def forward(self, **kwargs) -> Prediction:
        """Run one turn on the thread with the given input.

        Args:
            **kwargs: Must contain the input field specified in signature

        Returns:
            Prediction with:
                - the output field (name from signature): the raw response
                  text for string outputs, otherwise the validated model
                - trace: list[ThreadItem] - items completed during the turn
                - usage: Usage - token counts for the turn

        Raises:
            KeyError: If the input field is missing from ``kwargs``.
            ValueError: If Pydantic parsing fails for typed output
        """
        message = kwargs[self.input_field]

        # Append the output description unless it is only DSPy's default
        # "${field_name}" placeholder.
        extras = self.output_field_info.json_schema_extra or {}
        output_desc = extras.get("desc")
        placeholder = f"${{{self.output_field}}}"
        if output_desc and output_desc != placeholder:
            message = f"{message}\n\nPlease produce the following output: {output_desc}"

        wants_text = _is_str_type(self.output_type)

        # Typed outputs are requested through a JSON schema.
        turn_options = None
        if not wants_text:
            schema = self.output_type.model_json_schema()
            # Only fills the key in when the model's schema did not set it.
            schema.setdefault("additionalProperties", False)
            turn_options = TurnOptions(output_schema=schema)

        result = self.thread.run(message, turn_options)

        if wants_text:
            parsed_output = result.final_response
        else:
            try:
                parsed_output = self.output_type.model_validate_json(result.final_response)
            except Exception as e:
                # Include at most the first 500 characters of the response.
                response_preview = result.final_response[:500]
                if len(result.final_response) > 500:
                    response_preview += "..."
                raise ValueError(
                    f"Failed to parse Codex response as {self.output_type.__name__}: {e}\n"
                    f"Response: {response_preview}"
                ) from e

        fields = {self.output_field: parsed_output}
        return Prediction(**fields, trace=result.items, usage=result.usage)

    @property
    def thread_id(self) -> Optional[str]:
        """Id of the thread backing this instance.

        Returns:
            The underlying thread's current id, which may be None.
        """
        return self.thread.id

37
src/index.py Normal file
View File

@@ -0,0 +1,37 @@
import dspy
from modaic import PrecompiledAgent, PrecompiledConfig
from .codex_dspy import CodexModule
from .codex import Codex, CodexOptions, SandboxMode, ThreadOptions, TurnOptions
from dspy.signatures.signature import Signature
from dspy.primitives.prediction import Prediction
from typing import Optional
class CodexAgentConfig(PrecompiledConfig):
    """Configuration for the agent; the fields mirror the module's constructor arguments."""

    # DSPy signature; the default is a single message in, a single answer out.
    signature: str | type[Signature] = "message:str -> answer:str"
    # Working directory for the agent; defaults to an empty string.
    working_directory: str = ""
    model: Optional[str] = None
    sandbox_mode: Optional[SandboxMode] = None
    skip_git_repo_check: bool = False
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    codex_path_override: Optional[str] = None
class CodexAgent(PrecompiledAgent):
    """Precompiled agent that delegates every call to a ``CodexModule``.

    The module is built once, in the constructor, from the config fields.
    """

    config: CodexAgentConfig

    def __init__(self, config: CodexAgentConfig, **kwargs):
        super().__init__(config, **kwargs)
        # Each config field maps onto the module argument of the same name.
        forwarded = (
            "signature",
            "working_directory",
            "model",
            "sandbox_mode",
            "skip_git_repo_check",
            "api_key",
            "base_url",
            "codex_path_override",
        )
        module_kwargs = {name: getattr(config, name) for name in forwarded}
        self.codex_module = CodexModule(**module_kwargs)

    def forward(self, **kwargs) -> Prediction:
        """Forward the keyword arguments to the wrapped module and return its Prediction."""
        return self.codex_module.forward(**kwargs)