Files
cancer-pipeline/util/predictiondump.py
2025-11-30 16:46:59 -05:00

242 lines
7.6 KiB
Python

"""
predictiondump.py
~~~~~~~~~~~~~~~~~~~~~~
This module provides functions to convert model predictions into a structured format
that can be easily serialized and saved for later analysis.
Copyright 2025, Kai-Po Chang at Med NLP Lab, China Medical University, with aid from ChatGPT.
"""
__version__ = "0.1.0"
__date__ = "2025-10-05"
__author__ = ["Kai-Po Chang"]
__copyright__ = "Copyright 2025, Med NLP Lab, China Medical University"
__license__ = "MIT"
import json
from typing import Any, Iterable, Callable, Optional
from dataclasses import is_dataclass, asdict
from datetime import datetime, date
from decimal import Decimal
from collections.abc import Mapping
# --- helpers to make everything JSON-safe ---
def _to_json_safe(x: Any):
    """Convert a single leaf value into a JSON-serializable equivalent.

    Conversions, checked in this order:

    - ``datetime`` / ``date``             -> ISO-8601 string (``isoformat()``)
    - ``Decimal``                         -> ``float``
    - NumPy boolean / integer / floating  -> matching Python scalar
    - ``numpy.ndarray``                   -> nested lists via ``tolist()``

    Any other value is returned unchanged; this helper does not recurse into
    containers. NumPy is optional: when it cannot be imported, the NumPy
    conversions are skipped and the value is returned as-is.
    """
    if isinstance(x, (datetime, date)):
        return x.isoformat()
    if isinstance(x, Decimal):
        return float(x)

    try:
        import numpy as np  # optional dependency
    except ImportError:
        # Without NumPy installed, x cannot be a NumPy object.
        return x

    # np.bool_ is not a subclass of np.integer, so it needs its own check;
    # json.dumps cannot serialize it unconverted.
    if isinstance(x, np.bool_):
        return bool(x)
    if isinstance(x, np.integer):
        return int(x)
    if isinstance(x, np.floating):
        return float(x)
    if isinstance(x, np.ndarray):
        return x.tolist()
    return x
# --- core recursive dumper ---
def dump_prediction_plain(pred) -> dict[str, Any]:
    """Recursively convert a DSPy Prediction into a plain dict.

    Conversion rules, applied recursively:

    - ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    - A ``Prediction`` becomes a dict of its ``_store`` items, dropping keys
      that start with ``_lm_usage``, ``_inputs`` or ``_completions``.
    - Any other ``Mapping`` becomes a dict, dropping string keys that start
      with an underscore.
    - Lists, tuples and sets become lists.
    - Objects exposing a callable ``model_dump()``, ``dict()`` or
      ``to_dict()`` are converted through the first one that can be called
      without raising.
    - Anything else is returned unchanged (it is *not* made JSON-safe).

    DSPy is imported lazily and is optional: if it cannot be imported, no
    value can be a ``Prediction`` and that rule is skipped, so plain
    mappings and sequences can still be converted.

    Returns:
        The converted structure. A non-dict top-level result is wrapped as
        ``{"value": result}`` so that a dict is always returned.
    """
    try:
        from dspy.primitives.prediction import Prediction
    except ImportError:
        Prediction = None

    internal_prefixes = ("_lm_usage", "_inputs", "_completions")

    def to_plain(obj: Any):
        # Base primitives (return as-is)
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj

        # Nested Prediction -> dict, minus DSPy bookkeeping fields
        if Prediction is not None and isinstance(obj, Prediction):
            return {
                k: to_plain(v)
                for k, v in obj._store.items()
                if not (isinstance(k, str) and k.startswith(internal_prefixes))
            }

        # Mappings (dict-like): drop private string keys
        if isinstance(obj, Mapping):
            return {
                k: to_plain(v)
                for k, v in obj.items()
                if not (isinstance(k, str) and k.startswith("_"))
            }

        # Lists / tuples / sets
        if isinstance(obj, (list, tuple, set)):
            return [to_plain(v) for v in obj]

        # Objects that know how to export themselves. ``model_dump`` is tried
        # first so pydantic v2 models do not go through deprecated ``dict()``.
        for attr in ("model_dump", "dict", "to_dict"):
            exporter = getattr(obj, attr, None)
            if not callable(exporter):
                continue
            try:
                exported = exporter()
            except Exception:
                # Best effort: e.g. the attribute was found on a class rather
                # than an instance. Try the next exporter.
                continue
            return to_plain(exported)

        # Fallback: plain value
        return obj

    result = to_plain(pred)
    # If the top level isn't a dict, wrap it so we always return one
    return result if isinstance(result, dict) else {"value": result}
def dump_prediction(
    obj: Any,
    *,
    exclude_private: bool = True,
    exclude_keys: tuple[str, ...] = ("_lm_usage", "_inputs", "_completions"),
    custom_predicate: Optional[Callable[[str, Any], bool]] = None,
) -> Any:
    """
    Recursively convert a DSPy Prediction (or arbitrary nested structure)
    into JSON-serializable Python types, excluding selected internal fields.

    - exclude_private=True removes any string key starting with "_".
    - exclude_keys removes specific string keys regardless of exclude_private.
    - custom_predicate(key, value) -> bool can veto inclusion of a field;
      it is consulted only for keys that survived the two filters above.

    Conversion order for each value:

    1. ``None``/``str``/``int``/``float``/``bool`` are returned unchanged.
    2. DSPy ``Prediction``: its ``_store`` mapping is filtered and converted.
    3. pydantic ``BaseModel``: ``model_dump()`` (v2) or ``dict()`` (v1).
    4. Dataclass *instances*: ``dataclasses.asdict``.
    5. Any ``Mapping``: filtered and converted to a ``dict``.
    6. ``list``/``tuple``/``set``/``frozenset``: converted to a ``list``.
    7. Objects with a callable ``to_dict()`` or ``dict()``.
    8. Everything else is passed to ``_to_json_safe``.

    DSPy and pydantic are optional; when either cannot be imported its step
    is skipped. Cyclic structures are not detected.
    """
    # Resolve the optional third-party types once per top-level call instead
    # of re-attempting the imports on every recursive step.
    try:
        from dspy.primitives.prediction import Prediction  # type: ignore
    except ImportError:
        Prediction = None
    try:
        from pydantic import BaseModel  # type: ignore
    except ImportError:
        BaseModel = None

    def keep(key: Any, value: Any) -> bool:
        """Return True if the (key, value) pair should appear in the output."""
        if isinstance(key, str):
            if key in exclude_keys:
                return False
            if exclude_private and key.startswith("_"):
                return False
        if custom_predicate and not custom_predicate(key, value):
            return False
        return True

    def convert_items(pairs) -> dict:
        """Filter key/value pairs and convert the surviving values."""
        return {k: convert(v) for k, v in pairs if keep(k, v)}

    def convert(node: Any) -> Any:
        # Trivial primitives
        if node is None or isinstance(node, (str, int, float, bool)):
            return node

        # DSPy Prediction: data lives in the internal ``_store`` mapping.
        if Prediction is not None and isinstance(node, Prediction):
            store = getattr(node, "_store", {}) or {}
            return convert_items(store.items())

        # Pydantic (v2 or v1) objects
        if BaseModel is not None and isinstance(node, BaseModel):
            data = (
                node.model_dump()  # v2
                if hasattr(node, "model_dump")
                else node.dict()  # v1
            )
            return convert(data)

        # Dataclass instances. is_dataclass() is also True for dataclass
        # *types*, which asdict() rejects, so those are excluded here.
        if is_dataclass(node) and not isinstance(node, type):
            return convert(asdict(node))

        # Mapping (dict and any other Mapping implementation)
        if isinstance(node, Mapping):
            return convert_items(node.items())

        # Sequences and sets
        if isinstance(node, (list, tuple, set, frozenset)):
            return [convert(v) for v in node]

        # Objects with to_dict()/dict(). Only the export call itself is
        # best-effort; errors raised while converting the exported data
        # (e.g. from custom_predicate) propagate to the caller.
        for meth in ("to_dict", "dict"):
            exporter = getattr(node, meth, None)
            if not callable(exporter):
                continue
            try:
                data = exporter()
            except Exception:
                continue
            return convert(data)

        # Fallback to JSON-safe conversion (datetime, numpy, decimal, etc.)
        return _to_json_safe(node)

    return convert(obj)
# --- combining many predictions into one JSON blob ---
def dump_many_predictions(
    preds: Iterable[Any],
    *,
    key_fn: Optional[Callable[[Any, int], Optional[str]]] = None,
    **dump_kwargs,
) -> str:
    """
    Dump many predictions into a single JSON string.

    - By default returns a JSON array string.
    - If key_fn is provided, returns a JSON object string mapping
      key_fn(prediction, index) -> dumped prediction. Predictions for which
      key_fn returns None are omitted; if the same key is returned more than
      once, the later prediction replaces the earlier one.
    - dump_kwargs are forwarded to dump_prediction (e.g., exclude_keys=...).

    ``preds`` may be any iterable, including a one-shot iterator or
    generator; it is consumed exactly once.
    """
    # Materialize first: the keyed path needs each original prediction again
    # alongside its dumped form, and a generator cannot be iterated twice.
    originals = list(preds)
    dumped_list = [dump_prediction(p, **dump_kwargs) for p in originals]

    if not key_fn:
        return json.dumps(dumped_list, ensure_ascii=False, indent=2)

    mapping = {}
    for i, (p, d) in enumerate(zip(originals, dumped_list)):
        k = key_fn(p, i)
        if k is not None:
            mapping[k] = d
    return json.dumps(mapping, ensure_ascii=False, indent=2)