144 lines
4.0 KiB
Python
144 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Literal
|
|
|
|
from .exceptions import CodexError
|
|
from .items import ThreadItem, parse_thread_item
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class Usage:
|
|
input_tokens: int
|
|
cached_input_tokens: int
|
|
output_tokens: int
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ThreadError:
|
|
message: str
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ThreadStartedEvent:
|
|
type: Literal["thread.started"] = field(default="thread.started", init=False)
|
|
thread_id: str
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class TurnStartedEvent:
|
|
type: Literal["turn.started"] = field(default="turn.started", init=False)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class TurnCompletedEvent:
|
|
type: Literal["turn.completed"] = field(default="turn.completed", init=False)
|
|
usage: Usage
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class TurnFailedEvent:
|
|
type: Literal["turn.failed"] = field(default="turn.failed", init=False)
|
|
error: ThreadError
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ItemStartedEvent:
|
|
type: Literal["item.started"] = field(default="item.started", init=False)
|
|
item: ThreadItem
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ItemUpdatedEvent:
|
|
type: Literal["item.updated"] = field(default="item.updated", init=False)
|
|
item: ThreadItem
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ItemCompletedEvent:
|
|
type: Literal["item.completed"] = field(default="item.completed", init=False)
|
|
item: ThreadItem
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ThreadErrorEvent:
|
|
type: Literal["error"] = field(default="error", init=False)
|
|
message: str
|
|
|
|
|
|
# Union of every event class that parse_thread_event can return.
ThreadEvent = (
    # Thread lifecycle.
    ThreadStartedEvent
    # Turn lifecycle.
    | TurnStartedEvent | TurnCompletedEvent | TurnFailedEvent
    # Item lifecycle.
    | ItemStartedEvent | ItemUpdatedEvent | ItemCompletedEvent
    # Top-level error.
    | ThreadErrorEvent
)
|
|
|
|
|
|
def _ensure_dict(payload: object) -> dict[str, object]:
|
|
if isinstance(payload, dict):
|
|
return payload
|
|
raise CodexError("Event payload must be an object")
|
|
|
|
|
|
def _ensure_str(value: object, field: str) -> str:
|
|
if isinstance(value, str):
|
|
return value
|
|
raise CodexError(f"Expected string for {field}")
|
|
|
|
|
|
def _ensure_int(value: object, field: str) -> int:
|
|
if isinstance(value, int):
|
|
return value
|
|
raise CodexError(f"Expected integer for {field}")
|
|
|
|
|
|
def _parse_usage(payload: object) -> Usage:
    """Build a ``Usage`` from a decoded ``usage`` object.

    Args:
        payload: Decoded value expected to be a dict holding the three
            token-count keys.

    Returns:
        A ``Usage`` populated from ``input_tokens``, ``cached_input_tokens``
        and ``output_tokens``.

    Raises:
        CodexError: Propagated from the validation helpers when ``payload``
            is not a dict or a counter is missing / not an integer.
    """
    raw = _ensure_dict(payload)
    # Keys are validated in this order, so the first bad one is reported.
    counter_names = ("input_tokens", "cached_input_tokens", "output_tokens")
    counters = {name: _ensure_int(raw.get(name), name) for name in counter_names}
    return Usage(**counters)
|
|
|
|
|
|
def parse_thread_event(payload: object) -> ThreadEvent:
|
|
data = _ensure_dict(payload)
|
|
type_name = _ensure_str(data.get("type"), "type")
|
|
|
|
if type_name == "thread.started":
|
|
thread_id = _ensure_str(data.get("thread_id"), "thread_id")
|
|
return ThreadStartedEvent(thread_id=thread_id)
|
|
|
|
if type_name == "turn.started":
|
|
return TurnStartedEvent()
|
|
|
|
if type_name == "turn.completed":
|
|
usage = _parse_usage(data.get("usage"))
|
|
return TurnCompletedEvent(usage=usage)
|
|
|
|
if type_name == "turn.failed":
|
|
error_payload = _ensure_dict(data.get("error"))
|
|
message = _ensure_str(error_payload.get("message"), "error.message")
|
|
return TurnFailedEvent(error=ThreadError(message=message))
|
|
|
|
if type_name in {"item.started", "item.updated", "item.completed"}:
|
|
item_payload = data.get("item")
|
|
item = parse_thread_item(item_payload)
|
|
if type_name == "item.started":
|
|
return ItemStartedEvent(item=item)
|
|
if type_name == "item.updated":
|
|
return ItemUpdatedEvent(item=item)
|
|
return ItemCompletedEvent(item=item)
|
|
|
|
if type_name == "error":
|
|
message = _ensure_str(data.get("message"), "message")
|
|
return ThreadErrorEvent(message=message)
|
|
|
|
raise CodexError(f"Unsupported event type: {type_name}")
|